In [None]:
import os
import pickle
import time
from collections import Counter

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm

import random
# Seed every RNG this notebook draws from, not only Python's `random`:
# `random` drives the frame sampling, while torch drives weight initialisation
# and DataLoader shuffling. NumPy is seeded as well for completeness.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Hyper-parameters and run configuration shared by the cells below.
BATCH_SIZE = 8  # mini-batch size passed to the DataLoader
N_EPOCHS = 50  # number of passes over the generated pairs
LR = 1e-5  # learning rate handed to the Adam optimizer
K=50  # frames are drawn with random.randint(0, K-1), so each dataset entry must hold at least K frames
SAMPLES = 20000  # number of (frame, next-entry frame) pairs generated for training
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Colab-only: mount Google Drive so files under /content/drive (where the
# pickled dataset is read from) are accessible. This cell will fail outside
# Google Colab because `google.colab` is not available there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
def h_score(fx, gy):
    """Compute the scalar training objective for two batches of features.

    Both inputs are centred along the batch axis (dim 0). The result is

        -2 * mean_i <f_i, g_i>  +  sum(cov_f * cov_g)

    where ``cov_f`` and ``cov_g`` are the (biased, divide-by-N) covariance
    matrices of the centred features and ``*`` is the elementwise product.
    Because both covariances are symmetric, that sum equals
    ``trace(cov_f @ cov_g)``.

    Args:
        fx: 2-D tensor of features, one row per sample.
        gy: 2-D tensor of features with the same shape as ``fx``.

    Returns:
        A zero-dimensional tensor holding the objective value.
    """
    batch_size = fx.size(0)

    # Remove the per-feature batch mean from each view.
    f_centered = fx - fx.mean(dim=0)
    g_centered = gy - gy.mean(dim=0)

    # Feature-by-feature covariance matrices (normalised by N, not N - 1).
    cov_f = f_centered.t() @ f_centered / batch_size
    cov_g = g_centered.t() @ g_centered / batch_size

    # Average per-sample inner product between the two views.
    cross_term = (f_centered * g_centered).sum(dim=1).mean()
    cov_term = (cov_f * cov_g).sum()
    return -2 * cross_term + cov_term

In [None]:
class CNNModel(nn.Module):
    """Small single-channel CNN that maps an image to a 12-dimensional embedding.

    Pipeline: 8x8 convolution (1 -> 16 channels) -> ReLU -> 5x5 max-pool ->
    flatten -> Linear(1936, 640) -> sigmoid -> Linear(640, 12) -> sigmoid.

    The first linear layer is hard-wired to 1936 inputs, which is what a
    64x64 image produces: (64 - 8 + 1) = 57, 57 // 5 = 11, and
    16 * 11 * 11 = 1936. Other input sizes will fail at ``fc1``.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor.
        self.cnn1 = nn.Conv2d(1, 16, kernel_size=8, stride=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=5)

        # Fully connected read-out head.
        self.fc1 = nn.Linear(1936, 640)
        self.fc2 = nn.Linear(640, 12)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: float tensor of shape (batch, 1, H, W).

        Returns:
            A tuple ``(im_out, out)`` where ``im_out`` is the pooled feature
            map and ``out`` is the (batch, 12) sigmoid-activated embedding.
        """
        im_out = self.maxpool1(self.relu1(self.cnn1(x)))

        # Flatten every sample's feature map into one vector for the head.
        flat = im_out.view(im_out.size(0), -1)
        hidden = self.sigmoid(self.fc1(flat))
        out = self.sigmoid(self.fc2(hidden))
        return im_out, out

In [None]:
class CNNTrainer:
    """Builds the paired-frame dataset and trains ``CNNModel`` with ``h_score``.

    Attributes:
        u_frames: the object unpickled by ``get_dataloader``; ``None`` until
            that method has run.
        model: the network stored by ``train_cnn`` once training finishes;
            ``None`` until then.
    """

    def __init__(self) -> None:
        self.u_frames = None
        self.model = None

    def get_dataloader(self):
        """Load the pickled dataset and build a shuffled loader of frame pairs.

        For each ``i`` in ``range(SAMPLES)`` one frame is drawn at random
        (index in ``[0, K - 1]``) from entry ``i`` and one from entry
        ``i + 1`` (both indices wrap around the dataset length), giving a
        pair of frames from neighbouring entries.

        Returns:
            A ``DataLoader`` over a ``TensorDataset`` of ``(x, y)`` pairs that
            already live on ``DEVICE``.

        Raises:
            ValueError: if the unpickled dataset is empty.
        """
        print("Loading Data")

        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising. Only point this at a file you created or trust.
        with open('/content/drive/MyDrive/muri/dataset_list', 'rb') as f:
            u_frames = pickle.load(f)

        self.u_frames = u_frames

        n_entries = len(u_frames)
        if n_entries == 0:
            # Fail with a clear message instead of a ZeroDivisionError from
            # the modulo below.
            raise ValueError("dataset_list is empty; cannot build training pairs")

        data1 = []
        data2 = []
        for i in range(SAMPLES):
            # First choose an entry, then randomly select one of its K frames.
            # The two randint calls stay interleaved so the random sequence
            # consumed per sample is unchanged.
            data1.append(np.asarray(u_frames[i % n_entries][random.randint(0, K - 1)]))
            data2.append(np.asarray(u_frames[(i + 1) % n_entries][random.randint(0, K - 1)]))

        # np.asarray over the list stacks the frames into one fresh array, so
        # no extra per-frame copy is needed beforehand.
        x = torch.from_numpy(np.asarray(data1)).to(DEVICE)
        y = torch.from_numpy(np.asarray(data2)).to(DEVICE)
        torch_dataset = TensorDataset(x, y)

        loader = DataLoader(
            dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
        )
        return loader

    def train_cnn(self):
        """Train a fresh ``CNNModel`` for ``N_EPOCHS`` epochs and return it.

        Both frames of every pair go through the same network and the loss is
        ``h_score`` of the two embeddings. A checkpoint is written to the
        working directory after every second epoch. The trained model is also
        stored on ``self.model``.

        Returns:
            The trained ``CNNModel`` instance.
        """
        model = CNNModel()
        model.to(device=DEVICE)
        data_loader = self.get_dataloader()
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)

        print("Training Model")
        model.train()  # make the training mode explicit
        for epoch in range(N_EPOCHS):
            # Sum (not mean) of the per-batch losses over the epoch.
            training_loss = 0.0
            for x, y in data_loader:
                optimizer.zero_grad()
                # Add the channel axis that Conv2d expects: (B, H, W) -> (B, 1, H, W).
                x = x.unsqueeze(1)
                y = y.unsqueeze(1)
                # The model returns (feature_map, embedding); only the
                # embedding is used in the loss.
                loss = h_score(model(x.float())[1], model(y.float())[1])
                training_loss += loss.item()
                loss.backward()
                optimizer.step()
            print(f"Epoch [{epoch + 1}/{N_EPOCHS}], loss:{training_loss:.4f}")

            if (epoch + 1) % 2 == 0:
                print("Saving CNN Model")
                torch.save(model.state_dict(), f"new_cnn_model_{epoch+1}.pth")

        self.model = model
        return model

In [None]:
# Run the full training loop (loads the dataset, trains, writes checkpoints)
# and keep the returned trained network in `model` for the cells below.
model = CNNTrainer().train_cnn()

Loading Data
Training Model
Epoch [1/50], loss:-0.0440
Epoch [2/50], loss:-0.0179
Saving CNN Model
Epoch [3/50], loss:-0.0043
Epoch [4/50], loss:0.0076
Saving CNN Model
Epoch [5/50], loss:0.0009
Epoch [6/50], loss:-0.0017
Saving CNN Model
Epoch [7/50], loss:-0.0007
Epoch [8/50], loss:0.0008
Saving CNN Model
Epoch [9/50], loss:-0.0013
Epoch [10/50], loss:-0.0002
Saving CNN Model
Epoch [11/50], loss:-0.0005
Epoch [12/50], loss:-0.0007
Saving CNN Model
Epoch [13/50], loss:0.0001
Epoch [14/50], loss:0.0002
Saving CNN Model
Epoch [15/50], loss:0.0001
Epoch [16/50], loss:-0.0009
Saving CNN Model
Epoch [17/50], loss:-0.0014
Epoch [18/50], loss:-0.0009
Saving CNN Model
Epoch [19/50], loss:-0.0002
Epoch [20/50], loss:0.0000
Saving CNN Model
Epoch [21/50], loss:0.0000
Epoch [22/50], loss:-0.0000
Saving CNN Model
Epoch [23/50], loss:-0.0000
Epoch [24/50], loss:-0.0002
Saving CNN Model
Epoch [25/50], loss:-0.0000
Epoch [26/50], loss:-0.0005
Saving CNN Model
Epoch [27/50], loss:-0.0001
Epoch [28/50

Sanity Check

In [None]:
# Sanity check: push a single sample through the trained `model` and print the
# input and embedding shapes. Only the first batch is needed, so the loop
# exits immediately after it.
# NOTE: building a new CNNTrainer here reloads and resamples the whole dataset.
for x, y in CNNTrainer().get_dataloader():
    # Keep the first sample of the batch (as a batch of one) and add the
    # channel axis: (B, H, W) -> (1, H, W) -> (1, 1, H, W).
    x = x[:1].unsqueeze(1)
    print(x.shape) #input
    # Shape inspection only, so skip building the autograd graph.
    with torch.no_grad():
        print(model(x.float())[1].shape) #output
    break

Loading Data
torch.Size([1, 1, 64, 64])
torch.Size([1, 12])
