In [1]:
import pandas as pd
import numpy as np
import pickle

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from PIL import Image

import torch
import torch.nn as nn
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision

from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from tqdm.notebook import trange, tqdm

import random

from transformers import BertTokenizer, AutoTokenizer, AutoModel, BertModel

import matplotlib.pyplot as plt

from py_files.models import BertEncoder, ResnetPreTrained, ImageEncoder
from py_files.datasets import WSIBatchedDataset, GetRepsDataset

import faiss

import gc
import torch.optim as optim
import time

from torch.nn.parallel import DistributedDataParallel

In [2]:
def check_cuda():
    if torch.cuda.is_available():       
        device = torch.device("cuda")
        device_count = torch.cuda.device_count()
        print(f'There are {device_count} GPU(s) available.')
        for i in range(device_count):
            print('Device name:', torch.cuda.get_device_name(i))
        return device
    else:
        print('No GPU available, using the CPU instead.')
        device = torch.device("cpu")
        return device
    
device=check_cuda()

There are 2 GPU(s) available.
Device name: Tesla K80
Device name: Tesla K80


In [3]:
df = pd.read_csv('../df.csv')
df.head()

Unnamed: 0,patch_paths,pid,svs_paths,dtype,notes
0,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...
1,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...
3,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...
4,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...


## Train global KMeans clustering model

In [15]:
pids = df.pid.unique()
len(pids)

518

In [14]:
pid_percent_cluster=0.1
n_pids_cluster = int(pid_percent_cluster*len(pids))
n_pids_cluster

51

In [13]:
num_cluster=8

In [6]:
pids_cluster = np.random.choice(pids, size=n_pids_cluster)

normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

transform=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ConvertImageDtype(torch.float),
    normalize,
        ])

cluster_dataset = GetRepsDataset(df, pids_cluster, transform)
cluster_loader = torch.utils.data.DataLoader(cluster_dataset,batch_size=64, shuffle=True, \
                                             num_workers=1, pin_memory=True)
len(cluster_loader.dataset)

7677

In [7]:
print(torch.cuda.memory_allocated(device)*1e-6)

0.0


In [10]:
base_image_model = ResnetPreTrained()
image_encoder = nn.DataParallel(ImageEncoder(base_image_model))
image_encoder.to(device)
print(torch.cuda.memory_allocated(device)*1e-6)

94.279168


In [13]:
def train_global_cluster_model(encoder, dataloader, device, ncentroids=num_cluster):
    
    print("Getting patch representations...")
    
    rep_list = []
    path_list = []
    
    encoder.eval()
    
    for img, path in tqdm(dataloader):

        img=img.to(device)    

        in_batch_size = img.shape[0]
        
        with torch.no_grad():
            reps = encoder(img)
        rep_list.append(reps.detach().detach().cpu().numpy().reshape(in_batch_size, -1))
        path_list += path
        
        # clean up
        del img
        del reps
    
    print("\nTraining KMeans model...")
    
    X = np.concatenate(rep_list)
    X = np.ascontiguousarray(X)
    X = X.astype('float32')
    
    ncentroids = ncentroids
    niter = 300
    verbose = False
    d = X.shape[1]
    kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose, nredo=20)
    kmeans.train(X)
    
    print("\nFinished training KMeans model...")
    
    # clean up
    del encoder
    del dataloader
    torch.cuda.empty_cache()
    
    return kmeans

In [14]:
kmeans = train_global_cluster_model(image_encoder, cluster_loader, device, ncentroids=8)

Getting patch representations...


  0%|          | 0/105 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...

Training KMeans model...

Finished training KMeans model...


In [15]:
print(torch.cuda.memory_allocated(device)*1e-6)

94.279168


In [17]:
# # https://forums.fast.ai/t/gpu-memory-not-being-freed-after-training-is-over/10265/8?u=cedric
# def pretty_size(size):
# 	"""Pretty prints a torch.Size object"""
# 	assert(isinstance(size, torch.Size))
# 	return " × ".join(map(str, size))

# def dump_tensors(gpu_only=True):
# 	"""Prints a list of the Tensors being tracked by the garbage collector."""
# 	import gc
# 	total_size = 0
# 	for obj in gc.get_objects():
# 		try:
# 			if torch.is_tensor(obj):
# 				if not gpu_only or obj.is_cuda:
# 					print("%s:%s%s %s" % (type(obj).__name__, 
# 										  " GPU" if obj.is_cuda else "",
# 										  " pinned" if obj.is_pinned else "",
# 										  pretty_size(obj.size())))
# 					total_size += obj.numel()
# 			elif hasattr(obj, "data") and torch.is_tensor(obj.data):
# 				if not gpu_only or obj.is_cuda:
# 					print("%s → %s:%s%s%s%s %s" % (type(obj).__name__, 
# 												   type(obj.data).__name__, 
# 												   " GPU" if obj.is_cuda else "",
# 												   " pinned" if obj.data.is_pinned else "",
# 												   " grad" if obj.requires_grad else "", 
# 												   " volatile" if obj.volatile else "",
# 												   pretty_size(obj.data.size())))
# 					total_size += obj.data.numel()
# 		except Exception as e:
# 			pass        
# 	print("Total size:", total_size*1e-6)

# dump_tensors()

## Cluster all patches in the dataset

In [26]:
gcluster_dataset = GetRepsDataset(df, pids, transform)
gcluster_loader = torch.utils.data.DataLoader(gcluster_dataset,batch_size=128, shuffle=False, \
                                             num_workers=1, pin_memory=True)
len(gcluster_loader.dataset)

69757

In [25]:
print(torch.cuda.memory_allocated(device)*1e-6)

133.33862399999998


In [27]:
def cluster_all_patches(encoder, kmeans, dataloader, device, ncentroids=8):
    
    print("Clustering all patches...")
    
    encoder.eval()
    
    path_list, rep_list = [], []
    
    for img, path in tqdm(dataloader):
        img=img.to(device)
        
        in_batch_size = img.shape[0]
        
        with torch.no_grad():
            reps = image_encoder(img)
        rep_list.append(reps.detach().detach().cpu().numpy().reshape(in_batch_size, -1))
        path_list += path
        
        # clean up
        del img
        del reps
        
    X = np.concatenate(rep_list)
    X = np.ascontiguousarray(X)
    X = X.astype('float32')
    
    D, I = kmeans.index.search(X, 1)
    
    df = pd.DataFrame(path_list, columns=['patch_paths'])
    df['cluster_assignment'] = I
    
    print("\nFinished clustering all patches...")
    
    # clean up
    del encoder
    # del dataloader
    torch.cuda.empty_cache()
    
    return df

In [28]:
cluster_df = cluster_all_patches(image_encoder, kmeans, gcluster_loader, device, ncentroids=8)

Clustering all patches...


  0%|          | 0/545 [00:00<?, ?it/s]

The current process just got forked. Disabling parallelism to avoid deadlocks...

Finished clustering all patches...


In [29]:
print(torch.cuda.memory_allocated(device)*1e-6)

133.33862399999998


In [30]:
df_clustered = df.merge(cluster_df, on='patch_paths')

df_clustered.to_csv('../df_clustered.csv', index=False)

In [20]:
# del image_encoder
# del cluster_loader
# del cluster_dataset
# del gcluster_loader
# del gcluster_dataset

# gc.collect()

# torch.cuda.empty_cache()

# print(torch.cuda.memory_allocated(device)*1e-6)

538.26816


### checkpoint - 1

In [3]:
df = pd.read_csv('../df_clustered.csv')
df.head()

Unnamed: 0,patch_paths,pid,svs_paths,dtype,notes,cluster_assignment
0,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...,7
1,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...,0
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...,0
3,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...,0
4,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,GTEX-R55E-1726,/project/GutIntelligenceLab/ss4yd/gtex_data/ac...,train,2 pieces ~9.5x7 mm; 1 broken apart; good morph...,0


In [4]:
df[df.pid=='GTEX-R55E-1726'].cluster_assignment.value_counts()

0    166
7     16
1     15
4     11
Name: cluster_assignment, dtype: int64

In [5]:
def global_loss(cnn_code, rnn_code, eps=1e-8, temp3=10.0):

    batch_size = cnn_code.shape[0]
    labels = Variable(torch.LongTensor(range(batch_size))).to(cnn_code.device)

    if cnn_code.dim() == 2 :
        cnn_code = cnn_code.unsqueeze(0)
        rnn_code = rnn_code.unsqueeze(0)
        
    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)

    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1,2))
    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
    scores0 = scores0 / norm0.clamp(min=eps) * temp3
    
    # --> batch_size x batch_size
    if scores0.shape[0]!=1:
        scores0 = scores0.squeeze()
    else:
        scores0 = scores0.squeeze(0)
    

    scores1 = scores0.transpose(0, 1)
    loss0 = nn.CrossEntropyLoss()(scores0, labels)
    loss1 = nn.CrossEntropyLoss()(scores1, labels)
    return loss0, loss1

def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    """Returns cosine similarity between x1 and x2, computed along dim."""
    w12 = torch.sum(x1 * x2, dim)
    w1 = torch.norm(x1, 2, dim)
    w2 = torch.norm(x2, 2, dim)
    return (w12 / (w1 * w2).clamp(min=eps)).squeeze()

def attention_fn(query, context, temp1):
    """
    query: batch x ndf x queryL
    context: batch x ndf x ih x iw (sourceL=ihxiw)
    mask: batch_size x sourceL
    """
    batch_size, queryL = query.size(0), query.size(2)
    ih, iw = context.size(2), context.size(3)
    sourceL = ih * iw

    # --> batch x sourceL x ndf
    context = context.view(batch_size, -1, sourceL)
    contextT = torch.transpose(context, 1, 2).contiguous()

    # Get attention
    # (batch x sourceL x ndf)(batch x ndf x queryL)
    # -->batch x sourceL x queryL
    attn = torch.bmm(contextT, query)
    # --> batch*sourceL x queryL
    attn = attn.view(batch_size * sourceL, queryL)
    attn = nn.Softmax(dim=-1)(attn)

    # --> batch x sourceL x queryL
    attn = attn.view(batch_size, sourceL, queryL)
    # --> batch*queryL x sourceL
    attn = torch.transpose(attn, 1, 2).contiguous()
    attn = attn.view(batch_size * queryL, sourceL)

    attn = attn * temp1
    attn = nn.Softmax(dim=-1)(attn)
    attn = attn.view(batch_size, queryL, sourceL)
    # --> batch x sourceL x queryL
    attnT = torch.transpose(attn, 1, 2).contiguous()

    # (batch x ndf x sourceL)(batch x sourceL x queryL)
    # --> batch x ndf x queryL
    weightedContext = torch.bmm(context, attnT)

    return weightedContext, attn.view(batch_size, -1, ih, iw)

def local_loss(
    img_features, words_emb, cap_lens, temp1=4.0, temp2=5.0, temp3=10.0, agg="sum"
):

    batch_size = img_features.shape[0]

    att_maps = []
    similarities = []
    # cap_lens = cap_lens.data.tolist()
    for i in range(words_emb.shape[0]):

        # Get the i-th text description
        words_num = cap_lens[i]  # 25
        # TODO: remove [SEP]
        # word = words_emb[i, :, 1:words_num+1].unsqueeze(0).contiguous()    # [1, 768, 25]
        word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()  # [1, 768, 25]
        word = word.repeat(batch_size, 1, 1)  # [48, 768, 25]
        context = img_features  # [48, 768, 19, 19]

        weiContext, attn = attention_fn(
            word, context, temp1
        )  # [48, 768, 25], [48, 25, 19, 19]

        att_maps.append(
            attn[i].unsqueeze(0).contiguous()
        )  # add attention for curr index  [25, 19, 19]
        word = word.transpose(1, 2).contiguous()  # [48, 25, 768]
        weiContext = weiContext.transpose(1, 2).contiguous()  # [48, 25, 768]

        word = word.view(batch_size * words_num, -1)  # [1200, 768]
        weiContext = weiContext.view(batch_size * words_num, -1)  # [1200, 768]

        row_sim = cosine_similarity(word, weiContext)
        row_sim = row_sim.view(batch_size, words_num)  # [48, 25]

        row_sim.mul_(temp2).exp_()
        if agg == "sum":
            row_sim = row_sim.sum(dim=1, keepdim=True)  # [48, 1]
        else:
            row_sim = row_sim.mean(dim=1, keepdim=True)  # [48, 1]
        row_sim = torch.log(row_sim)

        similarities.append(row_sim)

    similarities = torch.cat(similarities, 1)  #
    similarities = similarities * temp3
    similarities1 = similarities.transpose(0, 1)  # [48, 48]

    labels = Variable(torch.LongTensor(range(batch_size))).to(similarities.device)

    loss0 = nn.CrossEntropyLoss()(similarities, labels)  # labels: arange(batch_size)
    loss1 = nn.CrossEntropyLoss()(similarities1, labels)
    return loss0, loss1, att_maps

In [6]:
pid_batch_size = 10
num_cluster=8

normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

transform=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ConvertImageDtype(torch.float),
    normalize,
        ])

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
wsi_batch_dataset = WSIBatchedDataset(df, dtype='train', tokenizer=tokenizer, \
                                      img_transform=transform, pid_batch_size=pid_batch_size)
train_dloader = torch.utils.data.DataLoader(wsi_batch_dataset,batch_size=1, shuffle=True, num_workers=1, pin_memory=True)

In [7]:
wsi_val_dataset = WSIBatchedDataset(df, dtype='val', tokenizer=tokenizer, \
                                      img_transform=transform, pid_batch_size=pid_batch_size)
val_dloader = torch.utils.data.DataLoader(wsi_val_dataset,batch_size=1, shuffle=True, num_workers=1, pin_memory=True)

In [8]:
len(val_dloader)

10

In [9]:
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated(device)*1e-6)

0.0


In [10]:
base_image_model = nn.DataParallel(ResnetPreTrained())
base_image_model.to(device)
bert_model = nn.DataParallel(BertEncoder(device=device))
bert_model.to(device)

torch.cuda.empty_cache()
print(torch.cuda.memory_allocated(device)*1e-6)

538.26816


In [12]:
def get_bert_params(text_encoder):
    freeze_modules = [text_encoder.module.model.embeddings, *text_encoder.module.model.encoder.layer[:-4]]
    non_freeze_modules = [*text_encoder.module.model.encoder.layer[:-4]]
    
    param_list = []
    for module in freeze_modules:
        for param in module.parameters():
            param.requires_grad = False
            param_list.append(param)

    for module in non_freeze_modules:
        for param in module.parameters():
            param_list.append(param)
            
    return param_list

In [21]:
def start_pretraining(img_encoder, text_encoder, train_loader, val_loader, device, pid_batch_size=pid_batch_size, epochs=50):
    
    params = list(img_encoder.parameters()) + get_bert_params(text_encoder)
    optimizer = optim.Adadelta([param for param in params \
                                if param.requires_grad == True],lr=1e-3,rho=0.95)
    print("Start training...\n")
    print(f"{'Epoch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Elapsed':^9}")
    print("-"*60)
    
    best_val_loss=10
    best_img_model=0
    best_text_model=0
    epochs_since_improvement = 0
    
    for epoch_i in range(epochs):
        
        
        if epochs_since_improvement == 20:
            break

        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            print("\nDECAYING learning rate.")
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.8
            print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))
        
        img_encoder.train()
        text_encoder.train()
        
        # Tracking time and loss
        t0_epoch = time.time()
        total_loss = 0
        
        for i, (img, text, attention, token_typ, img_seps, pids) in enumerate(train_loader):

            img_seps = [0]+[i.numpy()[0] for i in img_seps]
            img_seps = [[img_seps[i],img_seps[i+1]] for i in range(len(img_seps)-1)]
            
            img, text = img.squeeze(0).to(device),text.squeeze(0).to(device)
            attention, token_typ = attention.squeeze(0).to(device), token_typ.squeeze(0).to(device)
            
            text_outputs = text_encoder(text, attention, token_typ)
            img_outputs = img_encoder(img)

            # clean up
            del img, text, attention, token_typ
            gc.collect()

            cap_lens = text_outputs[2]
            cap_lens = [cap_lens[i].item() for i in np.arange(0, len(cap_lens), num_cluster)]

            pid_word_embeddings = [text_outputs[0][x:y] for x, y in img_seps]
            pid_sent_embeddings = [text_outputs[1][x:y] for x, y in img_seps]
            pid_img_embeddings = [img_outputs[x:y] for x, y in img_seps]

            cnn_code = torch.stack([x.mean(dim=0) for x in pid_img_embeddings])
            rnn_code = torch.stack([x[0] for x in pid_sent_embeddings])

            img_features = [x.unsqueeze(2).unsqueeze(2) for x in pid_img_embeddings]
            img_features = [x.permute(2,1,0,3) for x in img_features]
            img_features = torch.stack(img_features, dim=0).squeeze(1)

            words_embs = [x[0] for x in pid_word_embeddings]
            words_embs = torch.stack(words_embs)

            gloss0, gloss1 = global_loss(cnn_code, rnn_code)
            loss0, loss1, att_maps=local_loss(img_features, words_embs, cap_lens)
            loss = loss0 + loss1 + 0.1*gloss0 + 0.1*gloss1
            
            total_loss += loss.item()

            loss.backward()
            optimizer.step()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_loader)
        
        # =======================================
        #               Evaluation
        # =======================================
        if val_loader is not None:
            
            val_loss = evaluate(img_encoder, text_encoder, val_loader, device, test_dataloader=None)
            
            is_best = val_loss < best_val_loss
            best_val_loss = min(val_loss, best_val_loss)
            
            if is_best:
                best_img_model = img_encoder
                best_text_model = text_encoder
            
            if not is_best:
                epochs_since_improvement += 1
            else:
                epochs_since_improvement = 0
            
            
            
            # Print performance over the entire training data
            time_elapsed = time.time() - t0_epoch
            print(f"{epoch_i + 1:^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {time_elapsed:^9.2f}")

    
    return best_img_model, best_text_model

def evaluate(img_encoder, text_encoder, val_dataloader, device, test_dataloader=None,plot=False):
    img_encoder.eval()
    text_encoder.eval()
    
    # Tracking variables
    val_accuracy = []
    val_loss = []

    # For each batch in our validation set..
    for img, text, attention, token_typ, img_seps, pids in val_dataloader:
        
        img_seps = [0]+[i.numpy()[0] for i in img_seps]
        img_seps = [[img_seps[i],img_seps[i+1]] for i in range(len(img_seps)-1)]

        img, text = img.squeeze(0).to(device),text.squeeze(0).to(device)
        attention, token_typ = attention.squeeze(0).to(device), token_typ.squeeze(0).to(device)
        
        with torch.no_grad():
            text_outputs = text_encoder(text, attention, token_typ)
            img_outputs = img_encoder(img)
        
        cap_lens = text_outputs[2]
        cap_lens = [cap_lens[i].item() for i in np.arange(0, len(cap_lens), num_cluster)]
        
        pid_word_embeddings = [text_outputs[0][x:y] for x, y in img_seps]
        pid_sent_embeddings = [text_outputs[1][x:y] for x, y in img_seps]
        pid_img_embeddings = [img_outputs[x:y] for x, y in img_seps]

        cnn_code = torch.stack([x.mean(dim=0) for x in pid_img_embeddings])
        rnn_code = torch.stack([x[0] for x in pid_sent_embeddings])

        img_features = [x.unsqueeze(2).unsqueeze(2) for x in pid_img_embeddings]
        img_features = [x.permute(2,1,0,3) for x in img_features]
        img_features = torch.stack(img_features, dim=0).squeeze(1)

        words_embs = [x[0] for x in pid_word_embeddings]
        words_embs = torch.stack(words_embs)

        gloss0, gloss1 = global_loss(cnn_code, rnn_code)
        loss0, loss1, att_maps=local_loss(img_features, words_embs, cap_lens)
        loss = loss0 + loss1 + 0.1*gloss0 + 0.1*gloss1

        val_loss.append(loss.item())
        
    val_loss = np.mean(val_loss)
    
    return val_loss
        

In [22]:
best_img_model, best_text_model = start_pretraining(base_image_model, bert_model, train_dloader,\
                                                    val_loader=val_dloader, device=device, \
                                                    pid_batch_size=pid_batch_size)

Start training...

 Epoch  |  Train Loss  |  Val Loss  |  Elapsed 
------------------------------------------------------------
   1    |   8.300355   |  8.542545  |   84.61  
   2    |   7.058738   |  7.807192  |   83.50  
   3    |   6.711402   |  7.614232  |   81.73  
   4    |   6.558301   |  7.455561  |   80.20  
   5    |   6.364923   |  7.319637  |   80.15  
   6    |   6.292632   |  7.201505  |   79.69  
   7    |   6.191592   |  7.183501  |   79.20  
   8    |   6.075982   |  7.145250  |   78.56  
   9    |   6.056406   |  7.279336  |   78.57  
  10    |   6.013784   |  7.336645  |   77.71  
  11    |   5.982965   |  7.313731  |   77.32  
  12    |   5.932807   |  7.461548  |   77.35  
  13    |   5.862575   |  7.313488  |   76.99  
  14    |   5.927340   |  7.182718  |   77.03  
  15    |   5.923119   |  7.149048  |   76.73  
  16    |   5.968211   |  7.361364  |   76.74  

DECAYING learning rate.
The new learning rate is 0.000800

  17    |   6.005700   |  7.469219  |   75.7

In [15]:
# torch.save(best_img_model.state_dict(), './best_img_model.pth')
# torch.save(best_text_model.state_dict(), './best_text_model.pth')

In [24]:
torch.save(best_img_model.module.state_dict(), './best_img_model_dp.pth')
torch.save(best_text_model.module.state_dict(), './best_text_model_dp.pth')

In [26]:
# best_img_model.module.state_dict()