<a href="https://colab.research.google.com/github/vinotharjun/LargeScaleImageMemorability/blob/master/ResMemNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from skimage import io,transform
import torch
from tqdm import tqdm
from torch import nn
import torch.nn.functional as F
import os
import time
from torch.optim import lr_scheduler
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms, utils,models
import copy
import math
from skimage import io, transform
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [0]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
!wget http://memorability.csail.mit.edu/lamem.tar.gz
!tar -xf /content/lamem.tar.gz

--2020-03-31 06:47:56--  http://memorability.csail.mit.edu/lamem.tar.gz
Resolving memorability.csail.mit.edu (memorability.csail.mit.edu)... 128.30.195.49
Connecting to memorability.csail.mit.edu (memorability.csail.mit.edu)|128.30.195.49|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2708368436 (2.5G) [application/x-gzip]
Saving to: ‘lamem.tar.gz’


2020-03-31 06:49:00 (40.7 MB/s) - ‘lamem.tar.gz’ saved [2708368436/2708368436]



In [0]:
dataset_train = pd.read_csv("/content/drive/My Drive/image memorability/dataset/train_dataset.csv")
dataset_validation = pd.read_csv("/content/drive/My Drive/image memorability/dataset/val.csv")
dataset_train = dataset_train[:1000]
dataset_validation = dataset_validation[:200]

In [0]:
class AsetheticsDataset(Dataset):
      '''asethitics dataset'''
      def __init__(self,dataframe,root_dir,transform=None):
        """
            Args:
                csv_file (string): Path to the csv file with annotations.
                root_dir (string): Directory with all the images.
                transform (callable, optional): Optional transform to be applied
                    on a sample.
        """
        self.data = dataframe
        #     self.data.rename(columns=columns,inplace=True)
    #     self.data.drop(self.data.columns[[1,2,3,4,5,6,8,9]] , axis=1,inplace=True)
        self.root_dir = root_dir
        self.transform = transform
    
      def __len__(self):
        return len(self.data)
  
      def __getitem__(self,idx):
   
        if torch.is_tensor(idx):
          idx = idx.tolist()

        image_name =  os.path.join(self.root_dir,self.data.iloc[idx,0])
        image = io.imread(image_name)
        mem_val = self.data.iloc[idx,1]
#     return_sample={}
        return_sample = {
              'image':image,
              'memorability_score':mem_val 
        }
        if self.transform:
            return_sample = self.transform(return_sample)
    
     
        return return_sample

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        
        image,mem_val = sample['image'], sample["memorability_score"]
        
        h, w = image.shape[:2]
        
        # if isinstance(self.output_size, int):
        #     if h > w:
        #         new_h, new_w = self.output_size * h / w, self.output_size
        #     else:
        #         new_h, new_w = self.output_size, self.output_size * w / h
        # else:
        #     new_h, new_w = self.output_size

        # new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (self.output_size,self.output_size,3))
        return {'image': img, 'memorability_score': mem_val}
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        
        image, mem_val = sample['image'], sample['memorability_score']
#         print(type(torch.from_numpy(image)))
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
#         print(image.shape)
      
        image = image.transpose((2, 0, 1))
        
        return {'image': torch.from_numpy(image),
                'memorability_score': mem_val}
class Normalize(object):
    def __init__(self,mean,std):
        self.mean=mean
        self.std=std
    def __call__(self,sample):
        image, mem_val = sample["image"], sample["memorability_score"]
        normalized=  (image -self.mean) / self.std
        return {
            "image":normalized,
            "memorability_score" : mem_val
       }

transformed_dataset_train = AsetheticsDataset(dataset_train,root_dir="/content/lamem/images",
                                        transform=transforms.Compose([Rescale(28),ToTensor(),Normalize(0.5,0.5)
                                                          ]))

transformed_dataset_val= AsetheticsDataset(dataset_validation,root_dir="/content/lamem/images",
                                        transform=transforms.Compose([Rescale(28),ToTensor(),Normalize(0.5,0.5)
                                                          ]))

train_dataloader=DataLoader(transformed_dataset_train,batch_size=32,shuffle=True)
validation_dataloader=DataLoader(transformed_dataset_val,batch_size=32,shuffle=True)
dataloaders={
    "train":train_dataloader,
    "val":validation_dataloader
}
dataset_sizes ={
    "train":len(dataset_train),
    "val":len(dataset_validation)
}

In [0]:
len(transformed_dataset_train)

1000

In [0]:
class ConvolutionLayer(nn.Module):
    '''
    Args:
    in_channels = the number of channels of input channel (type int)
    output_channels = the number of output feature maps (type int)
    kernel_size = the size of the kernel (type tuple)
    strides = the length of the strides (type int)
    padding = takes two 1) valid, 2) same or some other integer value [default:"valid"]
    
    '''
    def __init__(self,in_channels,out_channels,kernel_size=(3,3),strides=1,padding="valid"):
        super(ConvolutionLayer,self).__init__()
        
        self.in_channels = in_channels
        self.out_channels= out_channels
        self.kernel_size = kernel_size
        self.padding     = padding
        self.strides     = strides
        #layer definition
        self.convolution_layer = self.conv_layer(in_channels,out_channels,kernel_size,strides,padding)
        self.batch_normalize = nn.BatchNorm2d(out_channels)
        self.relu       = nn.ReLU()
        

        
    def forward(self,x):
        x = self.convolution_layer(x)
        x = self.relu(x)
        x = self.batch_normalize(x)
        return x
    
    def conv_layer(self,in_channels,out_channels,kernel_size,strides,padding):
        if padding == "valid":
            padding =0
        elif padding == "same":
            strides = 1
            padding = math.floor(int((kernel_size[0]-1)/2))  
        return nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding)

In [0]:
class LstmCell(nn.Module):
    '''
    Args:
    input_dims : the input dimension (that takes single integer) (type int)
    hidden dims : the dimension of hidden and cell state (type int)
    attach_fc  : to ensure whether fully connected layer should be connected to last rnn cell [default:False]
    '''
    def __init__(self,input_dims,hidden_dims,attach_fc = False):
        super(LstmCell,self).__init__()
        
        self.input_dims = input_dims
        self.hidden_dims = hidden_dims
        self.attach_fc = attach_fc
        
        self.lstm_cell = nn.LSTMCell(input_size = self.input_dims,hidden_size = self.hidden_dims)
        if self.attach_fc == True:
            self.fc = nn.Linear(self.hidden_dims,1)
        
    def forward(self,x,hidden_state,cell_state):
        hidden_output,cell_output = self.lstm_cell(x,(hidden_state,cell_state))
        if self.attach_fc ==True:
            output = self.fc(hidden_output)
            return output
        else:
            return hidden_output,cell_output
        
#     def init_hidden(self, batch_size):
#         hidden = torch.tensor(next(self.parameters()).data.new(batch_size, self.hidden_dims), requires_grad=False)
#         cell = torch.tensor(next(self.parameters()).data.new(batch_size, self.hidden_dims), requires_grad=False)
#         return hidden.zero_(), cell.zero_()

In [0]:
class VRNet(nn.Module):
    def __init__(self,in_channels,cnn_layer,lstm_cell):
        super(VRNet,self).__init__()
        self.hidden_dims =128
        self.input_dims =128
        self.stage_lstm =lstm_cell(self.input_dims,self.hidden_dims)
        #stage1
        self.stage1_cnn = cnn_layer(in_channels=3,out_channels=32,kernel_size=(3,3),padding="same",strides=1)
        self.stage1_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage1 inter
        self.stage1_inter_cnn3x3 = cnn_layer(in_channels=32,out_channels=64,kernel_size=(3,3),padding="valid",strides=1)
        self.stage1_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="valid",strides=1)
        self.stage1_interpool = nn.AdaptiveAvgPool2d(1)
       

        #stage2
        self.stage2_cnn = cnn_layer(in_channels=32,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        # self.stage2_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage2 inter
        self.stage2_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage2_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage2_interpool = nn.AdaptiveAvgPool2d(1)

        #stage3
        self.stage3_cnn = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        # self.stage3_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage3 inter
        self.stage3_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage3_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage3_interpool = nn.AdaptiveAvgPool2d(1)

        #stage4
        self.stage4_cnn = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        # self.stage4_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage4 inter
        self.stage4_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage4_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage4_interpool = nn.AdaptiveAvgPool2d(1)

        #stage5
        self.stage5_cnn = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        # self.stage5_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage5 inter
        self.stage5_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage5_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage5_interpool = nn.AdaptiveAvgPool2d(1)

        #stage6
        self.stage6_cnn = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        # self.stage6_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage6 inter
        self.stage6_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage6_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage6_interpool = nn.AdaptiveAvgPool2d(1)

        #stage7
        self.stage7_cnn = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        # self.stage7_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage6 inter
        self.stage7_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage7_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage7_interpool = nn.AdaptiveAvgPool2d(1)

        #stage8
        self.stage8_cnn = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        # self.stage8_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage6 inter
        self.stage8_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage8_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage8_interpool = nn.AdaptiveAvgPool2d(1)


        #stage9
        self.stage9_cnn = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage9_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage6 inter
        self.stage9_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage9_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage9_interpool = nn.AdaptiveAvgPool2d(1)


        #stage7
        self.stage10_cnn = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage10_pool = nn.AvgPool2d(kernel_size=(3,3),stride=2)
        #stage6 inter
        self.stage10_inter_cnn3x3 = cnn_layer(in_channels=64,out_channels=64,kernel_size=(3,3),padding="same",strides=1)
        self.stage10_inter_cnn1x1 = cnn_layer(in_channels=64,out_channels=128,kernel_size=(1,1),padding="same",strides=1)
        self.stage10_interpool = nn.AdaptiveAvgPool2d(1)



         #linear
        self.fc = nn.Linear(self.hidden_dims,1)
        
        
    def forward(self,x,hidden_state,cell_state):
        #stage1
        x = self.stage1_cnn(x)
        x = self.stage1_pool(x)
        
        #stage1 inter
        x1 = self.stage1_inter_cnn3x3(x)
        x1 = self.stage1_inter_cnn1x1(x1)
        x1  = self.stage1_interpool(x1)
        x1 = x1.squeeze()
        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1
        
        #stage2
        x = self.stage2_cnn(x)
        # x = self.stage2_pool(x)
      
        #stage2 inter
        x1 = self.stage2_inter_cnn3x3(x)
        x1 = self.stage2_inter_cnn1x1(x1)
        x1 = self.stage2_interpool(x1)
        x1 = x1.squeeze()
        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1

        #stage3
        x = self.stage3_cnn(x)
        # x = self.stage3_pool(x)
        

        #stage4 inter
        x1 = self.stage3_inter_cnn3x3(x)
        x1 = self.stage3_inter_cnn1x1(x1)
        x1 = self.stage3_interpool(x1)
        x1 = x1.squeeze()
     
        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1

        #stage4
        x = self.stage4_cnn(x)
        # x = self.stage4_pool(x)

        #stage4 inter
        x1 = self.stage4_inter_cnn3x3(x)
        x1 = self.stage4_inter_cnn1x1(x1)
        x1 = self.stage4_interpool(x1)
        x1 = x1.squeeze()

        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1

         #stage5
        x = self.stage5_cnn(x)
        # x = self.stage5_pool(x)

        #stage5 inter
        x1 = self.stage5_inter_cnn3x3(x)
        x1 = self.stage5_inter_cnn1x1(x1)
        x1 = self.stage5_interpool(x1)
        x1 = x1.squeeze()

        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1

         #stage6
        x = self.stage6_cnn(x)
        # x = self.stage6_pool(x)

        #stage6 inter
        x1 = self.stage6_inter_cnn3x3(x)
        x1 = self.stage6_inter_cnn1x1(x1)
        x1 = self.stage6_interpool(x1)
        x1 = x1.squeeze()

        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1

          #stage7
        x = self.stage7_cnn(x)
        # x = self.stage7_pool(x)

        #stage7 inter
        x1 = self.stage7_inter_cnn3x3(x)
        x1 = self.stage7_inter_cnn1x1(x1)
        x1 = self.stage7_interpool(x1)
        x1 = x1.squeeze()

        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1

          #stage8
        x = self.stage8_cnn(x)
        # x = self.stage8_pool(x)

        #stage8 inter
        x1 = self.stage8_inter_cnn3x3(x)
        x1 = self.stage8_inter_cnn1x1(x1)
        x1 = self.stage8_interpool(x1)
        x1 = x1.squeeze()

        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1

          #stage9
        x = self.stage9_cnn(x)
        x = self.stage9_pool(x)

        #stage7 inter
        x1 = self.stage9_inter_cnn3x3(x)
        x1 = self.stage9_inter_cnn1x1(x1)
        x1 = self.stage9_interpool(x1)
        x1 = x1.squeeze()

        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1

          #stage10
        x = self.stage10_cnn(x)
        x = self.stage10_pool(x)

        #stage7 inter
        x1 = self.stage10_inter_cnn3x3(x)
        x1 = self.stage10_inter_cnn1x1(x1)
        x1 = self.stage10_interpool(x1)
        x1 = x1.squeeze()

        hidden_state,cell_state = self.stage_lstm(x1,hidden_state,cell_state)
        del x1


        return self.fc(hidden_state)
        
        
    def init_hidden(self, batch_size):
        hidden = torch.tensor(next(self.parameters()).data.new(batch_size, self.hidden_dims), requires_grad=False)
        cell = torch.tensor(next(self.parameters()).data.new(batch_size, self.hidden_dims), requires_grad=False)
        return hidden.zero_(), cell.zero_()

model = VRNet(3,ConvolutionLayer,LstmCell).to(device).double()
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
    if isinstance(m,nn.Linear):
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)
        
model.apply(weights_init)
print("")




In [0]:
# t = torch.rand([32,3,224,224]).to(device).double()

In [0]:
# hidden_state ,cell_state = model.init_hidden(32)

In [0]:

# with torch.no_grad():
#   x =model(t,hidden_state,cell_state)
#   print(x.shape)

In [0]:

def getMSE(d1, d2):
        mse = 0.0
        for a,b in zip(d1, d2):
            mse += (a-b)**2
        return mse / len(d1)

def getRankCorrelation(predicted, gt=None):

        if gt is None:
            return "needed gt"
        gt = np.array(gt).tolist()
        predicted = np.array(predicted).squeeze().tolist()

        n = min(len(predicted), len(gt))
        if n < 2:
            return 0

        gt = gt[:n]
        predicted = predicted[:n]
        mse = getMSE(gt, predicted)

        def get_rank(list_a):
            rank_list = np.zeros(len(list_a))
            idxs = np.array(list_a).argsort()
            for rank, i in enumerate(idxs):
                rank_list[i] = rank

            return rank_list

        gt_rank = get_rank(gt)
        predicted_rank = get_rank(predicted)
        ssd = 0
        for i in range(len(predicted_rank)):
            ssd += (gt_rank[i] -  predicted_rank[i])**2

        rc = 1-(6*ssd/(n*n*n - n))


        return rc, mse
def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    since = time.time()
    running_loss_history = []
    val_running_loss_history=[]
    orignal_model=None
    best_model_wts = copy.deepcopy(model.state_dict())
    low_loss = np.inf
    stop = False


    for epoch in range(num_epochs):

       
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        if stop == True:
          break

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            rank_corr=[]

            # Iterate over data.
            for batched_data in tqdm(dataloaders[phase]):
                
                inputs=batched_data["image"]
                inputs = inputs.to(device)
                hidden_state ,cell_state = model.init_hidden(inputs.size(0))
                if inputs.size(0) == 16:
                   print(hidden_state.shape)
                labels=batched_data["memorability_score"]
                labels=labels.view(-1,1).double()
                labels = labels.to(device)
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs,hidden_state,cell_state)
                    loss = criterion(outputs, labels)
                    if phase == "val":
                      rc,mse = getRankCorrelation(outputs.squeeze().tolist(),labels.squeeze().tolist())
                      rank_corr.append(rc)
                    print("  batch loss:    ",loss.item())

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
               
                running_loss += loss.item() * inputs.size(0)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            if phase=="train":
              running_loss_history.append(epoch_loss)
            else:
              val_running_loss_history.append(epoch_loss)
            

            print('{} Loss: {:.4f}'.format(
                phase, epoch_loss))
            

            # deep copy the model
            if phase == "val":
                rho =sum(rank_corr)/len(rank_corr)
                print("rank correlation,final",rho)
                if rho >=0.67:
                  print("rank correlation final",rho)
                  stop = True
            
            if phase == 'val' and epoch_loss < low_loss:
                print("saving best model......")
                low_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts,"/content/drive/My Drive/image memorability/saved models/resnet50_weights_over.pt")

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    original_model =copy.deepcopy(model)
    model.load_state_dict(best_model_wts)
    return model,original_model,running_loss_history,val_running_loss_history

In [0]:
# f= open("/content/drive/My Drive/image memorability/saved models/resnet50_weights_over.pt","w+")

In [0]:
# sample =iter(train_dataloader).next()

In [0]:
# h,c=model.init_hidden(8)

In [0]:
# sample["memorability_score"].view(-1,1).squeeze()

In [0]:

        # rc, _ = stats.spearmanr(a=predicted, b=gt, axis=0)

In [0]:

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [0]:

best_model,_,_ ,_= train_model(best_model,criterion,optimizer,exp_lr_scheduler,100)




  0%|          | 0/32 [00:00<?, ?it/s][A[A[A

Epoch 0/99
----------





  3%|▎         | 1/32 [00:02<01:29,  2.90s/it][A[A[A

  batch loss:     0.006938466257007716





  6%|▋         | 2/32 [00:08<01:48,  3.61s/it][A[A[A

  batch loss:     0.013267911606045839





  9%|▉         | 3/32 [00:11<01:41,  3.49s/it][A[A[A

  batch loss:     0.01158581663754972





 12%|█▎        | 4/32 [00:15<01:42,  3.67s/it][A[A[A

  batch loss:     0.01839847821169457





 16%|█▌        | 5/32 [00:19<01:42,  3.80s/it][A[A[A

  batch loss:     0.010877960597763789





 19%|█▉        | 6/32 [00:22<01:33,  3.58s/it][A[A[A

  batch loss:     0.009538517723665145





 22%|██▏       | 7/32 [00:25<01:27,  3.49s/it][A[A[A

  batch loss:     0.01116875338793686





 25%|██▌       | 8/32 [00:28<01:20,  3.34s/it][A[A[A

  batch loss:     0.018898556518095272





 28%|██▊       | 9/32 [00:32<01:16,  3.31s/it][A[A[A

  batch loss:     0.00595325569823942





 31%|███▏      | 10/32 [00:35<01:09,  3.17s/it][A[A[A

  batch loss:     0.015677265411407692





 34%|███▍      | 11/32 [00:38<01:09,  3.32s/it][A[A[A

  batch loss:     0.012398704527942544





 38%|███▊      | 12/32 [00:41<01:04,  3.21s/it][A[A[A

  batch loss:     0.019986967339910563





 41%|████      | 13/32 [00:45<01:02,  3.30s/it][A[A[A

  batch loss:     0.017342121519889873





 44%|████▍     | 14/32 [00:48<00:59,  3.28s/it][A[A[A

  batch loss:     0.019678890767970324





 47%|████▋     | 15/32 [00:53<01:06,  3.93s/it][A[A[A

  batch loss:     0.018723270783526563





 50%|█████     | 16/32 [01:00<01:14,  4.68s/it][A[A[A

  batch loss:     0.0214878982039744





 53%|█████▎    | 17/32 [01:03<01:02,  4.18s/it][A[A[A

  batch loss:     0.01260626115572195





 56%|█████▋    | 18/32 [01:16<01:38,  7.01s/it][A[A[A

  batch loss:     0.011604055576214587





 59%|█████▉    | 19/32 [01:32<02:05,  9.67s/it][A[A[A

  batch loss:     0.01193473157490318





 62%|██████▎   | 20/32 [01:36<01:34,  7.84s/it][A[A[A

  batch loss:     0.01813946370435643





 66%|██████▌   | 21/32 [01:42<01:20,  7.35s/it][A[A[A

  batch loss:     0.02317679813263295





 69%|██████▉   | 22/32 [01:45<01:00,  6.01s/it][A[A[A

  batch loss:     0.011793862055841195





 72%|███████▏  | 23/32 [01:49<00:48,  5.42s/it][A[A[A

  batch loss:     0.014104092685826342





 75%|███████▌  | 24/32 [01:59<00:55,  6.92s/it][A[A[A

  batch loss:     0.017095069347702167





 78%|███████▊  | 25/32 [02:10<00:56,  8.04s/it][A[A[A

  batch loss:     0.01671033185248636





 81%|████████▏ | 26/32 [02:13<00:39,  6.50s/it][A[A[A

  batch loss:     0.01270776132177065





 84%|████████▍ | 27/32 [02:16<00:27,  5.56s/it][A[A[A

  batch loss:     0.014390193472367061





 88%|████████▊ | 28/32 [02:20<00:19,  4.88s/it][A[A[A

  batch loss:     0.01626180071205715





 91%|█████████ | 29/32 [02:23<00:13,  4.39s/it][A[A[A

  batch loss:     0.014829790301724353





 94%|█████████▍| 30/32 [02:26<00:08,  4.03s/it][A[A[A

  batch loss:     0.014292543745402862





 97%|█████████▋| 31/32 [02:31<00:04,  4.32s/it][A[A[A

  batch loss:     0.01556304804160457





100%|██████████| 32/32 [02:32<00:00,  4.76s/it]



  0%|          | 0/7 [00:00<?, ?it/s][A[A[A

  batch loss:     0.033509915924109276
train Loss: 0.0149





 14%|█▍        | 1/7 [00:03<00:20,  3.35s/it][A[A[A

  batch loss:     0.01034178177386198





 29%|██▊       | 2/7 [00:07<00:17,  3.56s/it][A[A[A

  batch loss:     0.007604582945936754





 43%|████▎     | 3/7 [00:12<00:16,  4.01s/it][A[A[A

  batch loss:     0.009662935197448309





 57%|█████▋    | 4/7 [00:15<00:11,  3.69s/it][A[A[A

  batch loss:     0.008676476085882182





 71%|███████▏  | 5/7 [00:21<00:08,  4.27s/it][A[A[A

  batch loss:     0.005366097618310501





 86%|████████▌ | 6/7 [00:26<00:04,  4.62s/it][A[A[A

  batch loss:     0.006766579961967743





100%|██████████| 7/7 [00:27<00:00,  3.87s/it]

  batch loss:     0.0027273751101957395
val Loss: 0.0079
rank correlation,final 0.6803668681548866
rank correlation final 0.6803668681548866
saving best model......

Epoch 1/99
----------
Training complete in 2m 59s





In [0]:
len(transformed_dataset_train)

100