Name - Rachit Yadav

Role - Fetch the data from MongoDB and preprocess it to feed it to the neural network

Algorithm - Neural Network



Name - Charudatta Manwatkar

Role - Develop the model and tune its hyperparameters

Algorithm - Neural Network

This notebook connects to MongoDB to fetch image embeddings stored in several collections. Each collection holds embeddings generated by a different pretrained model. The pretrained models used are:
1. VGG-11 (collections Images, embedding size 1000, and image_4, embedding size 4096)
2. ResNet-18 (collection image_2, embedding size 512)
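As a small sketch, the collection-to-embedding-size mapping above can be kept in one place so that the pair model's input width follows from the chosen collection. EMBEDDING_DIMS, SOURCE_MODEL and pair_input_dim are hypothetical helpers, not part of the original notebook; the network defined later hardcodes 4096*2, which matches image_4.

```python
# Embedding sizes per MongoDB collection, as listed above
EMBEDDING_DIMS = {"Images": 1000, "image_4": 4096, "image_2": 512}

# Pretrained model that produced each collection
SOURCE_MODEL = {"Images": "VGG-11", "image_4": "VGG-11", "image_2": "ResNet-18"}


def pair_input_dim(collection: str) -> int:
    """Width of the classifier input when two embeddings are concatenated side by side."""
    return 2 * EMBEDDING_DIMS[collection]


for name in EMBEDDING_DIMS:
    print(f"{name}: {SOURCE_MODEL[name]}, pair input width {pair_input_dim(name)}")
```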

In [0]:
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession

In [0]:
from pyspark.sql.types import *
from pyspark.sql.functions import *
from datetime import datetime
from pyspark.sql.window import Window

In [0]:
# Set up the Spark session with the MongoDB Spark connector.
# spark.network.timeout is set once; the executor heartbeat interval must stay well below it.
sc1 = SparkSession.builder.appName('group3')\
    .config("spark.jars.packages", "org.mongodb.spark:mongo-spark-connector_2.12:2.4.0")\
    .config("spark.executor.heartbeatInterval", "3600s")\
    .config("spark.mongodb.input.uri", "mongodb+srv://root:root@cluster0.eq07a.mongodb.net/Group3.image_4")\
    .config("spark.mongodb.output.uri", "mongodb+srv://root:root@cluster0.eq07a.mongodb.net/Group3.image_4")\
    .config("spark.databricks.io.cache.enabled", "true")\
    .config("spark.network.timeout", "7200s").getOrCreate()

In [0]:
sparkContext = sc1.sparkContext
sparkContext.setLogLevel('OFF')

The code below builds the connection string for a MongoDB collection. To connect to a different collection, assign its name to the collection variable.

In [0]:
database = 'Group3'
collection = 'image_4'
user_name = 'root'
password = 'root'
address = 'cluster0.eq07a.mongodb.net'
connection_string = f"mongodb+srv://{user_name}:{password}@{address}/{database}.{collection}"
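The same connection-string pattern is repeated for every collection in this notebook. A minimal sketch of a hypothetical helper, mongo_uri, that builds it from a collection name; the defaults are the values used here, and urllib.parse.quote_plus percent-encodes the user name and password in case they contain characters such as @ or :.

```python
from urllib.parse import quote_plus


def mongo_uri(collection: str,
              database: str = "Group3",
              user_name: str = "root",
              password: str = "root",
              address: str = "cluster0.eq07a.mongodb.net") -> str:
    """Build a mongodb+srv URI pointing at database.collection."""
    # Percent-encode credentials so characters like '@' or ':' do not break the URI
    return (f"mongodb+srv://{quote_plus(user_name)}:{quote_plus(password)}"
            f"@{address}/{database}.{collection}")


print(mongo_uri("image_4"))
print(mongo_uri("train_csv"))
```

With such a helper, each connection cell reduces to a single line such as connection_string = mongo_uri('train_csv').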

In [0]:
# Connect to the database and load the image file names with their corresponding embeddings, using connection_string
dff = sc1.read.format("mongo").option("uri", connection_string).load().cache()

The code below connects to MongoDB to fetch the train_csv collection. It has the columns posting_id, image, image_phash, title and label_group, of which only image and label_group are used.

In [0]:
database = 'Group3'
collection = 'train_csv'
user_name = 'root'
password = 'root'
address = 'cluster0.eq07a.mongodb.net'
connection_string = f"mongodb+srv://{user_name}:{password}@{address}/{database}.{collection}"

In [0]:
lkp = sc1.read.format("mongo").option("uri", connection_string).load().select('image', 'label_group').cache()

We keep only the label groups that contain more than 5 images.

In [0]:
rel_labels = lkp.groupby('label_group').count().filter(col('count')>5).select('label_group')

In [0]:
# lkp1 holds the image file names and label groups for groups with more than 5 images.
# Repartitioning lkp1 into 4 partitions improved performance.
lkp1 = lkp.join(rel_labels, 'label_group').repartition(4)

Cross join lkp1 with itself so that every image is paired with every image (including itself), which yields both matching and non-matching pairs. A random 90% of these pairs is then kept.
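To make the pairing semantics concrete, here is a plain-Python sketch (no Spark) on three made-up images. cross_join_pairs mirrors a self crossJoin: every ordered pair, including each image paired with itself. unique_pairs is shown only as an alternative (unordered pairs, no self-pairs) and is not what this notebook does; both function names and the sample rows are hypothetical.

```python
from itertools import combinations, product


def cross_join_pairs(rows):
    """Mimic a self cross join: all ordered pairs, self-pairs included, labelled 1 if same group."""
    return [(img1, img2, int(g1 == g2)) for (img1, g1), (img2, g2) in product(rows, rows)]


def unique_pairs(rows):
    """Alternative: unordered pairs without self-pairs."""
    return [(img1, img2, int(g1 == g2)) for (img1, g1), (img2, g2) in combinations(rows, 2)]


# Hypothetical (image, label_group) rows
rows = [("a.jpg", "g1"), ("b.jpg", "g1"), ("c.jpg", "g2")]
full = cross_join_pairs(rows)
uniq = unique_pairs(rows)
print(len(full), sum(label for _, _, label in full))  # 9 5
print(len(uniq), sum(label for _, _, label in uniq))  # 3 1
```

The self-pairs such as (a.jpg, a.jpg) are trivially positive, and each non-self pair appears twice, once in each order. In Spark, an unordered set could roughly be obtained with a conditional self join on image1 < image2 instead of crossJoin.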

In [0]:
cross_j = lkp1.crossJoin(lkp1).sample(fraction=0.9,seed=3).cache()

In [0]:
cross_j.rdd.getNumPartitions()

In [0]:
newcolumns = ['label_code1', 'image1', 'label_code2', 'image2']

result_df_ren = cross_j.toDF(*newcolumns)

In [0]:
from pyspark.sql import functions as f

# binary_label is '1' when both images belong to the same label group, otherwise '0'
cresult_df_add = result_df_ren.withColumn(
    'binary_label',
    f.when(result_df_ren.label_code1 == result_df_ren.label_code2, '1').otherwise('0')
).select('image1', 'image2', 'binary_label').cache()

In [0]:
cresult_df_add.rdd.getNumPartitions()

After the cross join, non-matching pairs far outnumber matching pairs, because a pair is positive only when both images come from the same label group. Therefore, below we downsample the negative pairs.
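The imbalance follows from simple counting. With group sizes k_1, ..., k_m and n = k_1 + ... + k_m images, a full self cross join has n^2 ordered pairs, of which only k_1^2 + ... + k_m^2 are positive. A minimal sketch of that arithmetic (pair_counts is a hypothetical helper; the 0.9 sample on the cross join scales both counts by roughly the same factor):

```python
def pair_counts(group_sizes):
    """Return (positives, negatives) for a full self cross join.

    Pairs are ordered and include each image paired with itself,
    as in a crossJoin of a DataFrame with itself before sampling.
    """
    n = sum(group_sizes)
    positives = sum(k * k for k in group_sizes)
    return positives, n * n - positives


print(pair_counts([6, 6, 7]))  # (121, 240)
pos, neg = pair_counts([6] * 100)
print(pos, neg, pos / (pos + neg))  # 3600 356400 0.01
```

For 100 groups of 6 images each, only 1% of the pairs are positive, so without downsampling a classifier could reach about 99% accuracy by always predicting "no match".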

In [0]:
cresult_df_pos = cresult_df_add.filter(col('binary_label')==1)

In [0]:
cnt_pos = cresult_df_pos.count()

In [0]:
cresult_df_neg = cresult_df_add.filter(col('binary_label')==0)

In [0]:
cnt_neg = cresult_df_neg.count()

In [0]:
# Randomly sample negative pairs with fraction cnt_pos/cnt_neg, then cap the result at cnt_pos rows
cresult_df_neg_red = cresult_df_neg.sample(withReplacement=False,fraction=cnt_pos/cnt_neg).limit(cnt_pos)
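Spark's sample without replacement keeps each row independently, so the negative sample is only approximately cnt_pos rows; limit(cnt_pos) caps it but cannot top it up. For comparison, a plain-Python sketch of exact balancing with random.Random.sample; balance_pairs is a hypothetical helper and the inputs are assumed to be in-memory lists of pairs.

```python
import random


def balance_pairs(positives, negatives, seed=None):
    """Keep every positive pair plus an equally sized random subset of negatives.

    Sampling is exact and without replacement, unlike a Bernoulli sample
    followed by a limit.
    """
    rng = random.Random(seed)
    k = min(len(positives), len(negatives))
    return positives + rng.sample(negatives, k)


pos = [("p", i) for i in range(5)]
neg = [("n", i) for i in range(50)]
balanced = balance_pairs(pos, neg, seed=3)
print(len(balanced))  # 10
```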

In [0]:
# Combine the positive pairs and the downsampled negative pairs
data = cresult_df_pos.union(cresult_df_neg_red).cache()

In [0]:
#extract file name from the path
dff_new = dff.rdd.map(lambda x:[x[0].split('/')[-1],x[1],x[2]]).toDF()

In [0]:
newcolumns = [ 'image1', 'array']

dff_new_wid_col=dff_new.select('_1','_2').toDF(*newcolumns).cache()

In [0]:
# Fetch the embedding for a given image file name (runs one Spark filter + collect per call)
def conv(file):
    
    return dff_new_wid_col.filter(col('image1')==file).select('array').collect()
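conv runs a separate Spark filter and collect for every image name, i.e. one Spark job per image in every batch. A sketch of an alternative, assuming the embedding table (dff_new_wid_col) fits in driver memory: collect it once into a dictionary keyed by file name. The helper name is hypothetical, and parsing with ast.literal_eval assumes the array column holds string literals such as "[0.1, 0.2]", which is what the training loop's parsing of these values implies.

```python
import ast


def build_embedding_lookup(rows):
    """Map image file name -> embedding (list of floats).

    rows: iterable of (file_name, embedding) pairs, e.g. the Rows returned by
    dff_new_wid_col.collect(). String embeddings such as "[0.1, 0.2]" are
    parsed with ast.literal_eval; other sequences are copied into lists.
    """
    lookup = {}
    for name, emb in rows:
        lookup[name] = ast.literal_eval(emb) if isinstance(emb, str) else list(emb)
    return lookup


lookup = build_embedding_lookup([("a.jpg", "[0.1, 0.2]"), ("b.jpg", [0.3, 0.4])])
print(lookup["a.jpg"], lookup["b.jpg"])  # [0.1, 0.2] [0.3, 0.4]
```

In the notebook this could be used as lookup = build_embedding_lookup(dff_new_wid_col.collect()), followed by torch.tensor([lookup[x] for x in imgs_nm1], dtype=torch.float32) in place of the conv calls. If a file name occurs more than once, the dictionary keeps the last row, whereas conv(x)[0] takes the first.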

Modelling

In [0]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

Initializing the model

In [0]:
%%time

# import torch
import numpy as np
from torch import nn

model = nn.Sequential(nn.Linear(4096*2, 4096),
                      nn.BatchNorm1d(4096),
                      nn.LeakyReLU(0.1),
                      nn.Linear(4096, 4096//2),
                      nn.BatchNorm1d(4096//2),
                      nn.ReLU(),
                      nn.Linear(4096//2,1)).to(device)
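A quick worked count of trainable parameters: a Linear(i, o) layer has i*o weights plus o biases, and BatchNorm1d(c) has 2*c affine parameters (its running mean and variance are buffers, not parameters). The sketch generalizes the hidden widths to embedding_dim and embedding_dim // 2, which is an assumption; the cell above hardcodes 4096. The function names are hypothetical.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights (n_in * n_out) plus biases (n_out) of nn.Linear."""
    return n_in * n_out + n_out


def batchnorm_params(n_features: int) -> int:
    """Affine weight and bias of nn.BatchNorm1d (running stats are buffers)."""
    return 2 * n_features


def pair_mlp_params(embedding_dim: int) -> int:
    """Linear(2d, d) -> BN(d) -> Linear(d, d//2) -> BN(d//2) -> Linear(d//2, 1); activations add nothing."""
    d = embedding_dim
    return (linear_params(2 * d, d) + batchnorm_params(d)
            + linear_params(d, d // 2) + batchnorm_params(d // 2)
            + linear_params(d // 2, 1))


print(pair_mlp_params(4096))  # 41963521
print(pair_mlp_params(512))   # 657921
```

For image_4 this gives 41,963,521 parameters, about 80% of them in the first Linear layer; in PyTorch, sum(p.numel() for p in model.parameters()) should report the same figure.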


In [0]:
criterion = nn.BCEWithLogitsLoss()
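# batch_size is set by the hyperparameter loop during training
# Maximum number of iterations per hyperparameter configuration
iteration = 20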
total = data.count()
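BCEWithLogitsLoss applies the sigmoid inside the loss, and the network above ends in a plain Linear layer, so its outputs are logits rather than probabilities. Since sigmoid(0) = 0.5, a probability threshold of 0.5 corresponds to a logit threshold of 0. A small plain-Python check (the helper names are illustrative):

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def predict(logit: float) -> bool:
    """Positive when the predicted probability exceeds 0.5, i.e. when the logit is > 0."""
    return logit > 0.0


logits = [-2.0, -0.2, 0.2, 0.4, 3.0]
print([sigmoid(z) > 0.5 for z in logits])  # [False, False, True, True, True]
print([predict(z) for z in logits])        # [False, False, True, True, True]
print([z > 0.5 for z in logits])           # [False, False, False, False, True]
```

Comparing logits directly against 0.5 would label the logits 0.2 and 0.4 as negative even though their probabilities exceed 0.5.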



In [0]:
# train test split
s_data = data.randomSplit([0.8, 0.2])
train_data = s_data[0].cache()
test_data = s_data[1].cache()
train_loss_lst = []
test_loss_lst = []
train_accuracy_lst = []
test_accuracy_lst = []
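DataFrame.sample without replacement keeps each row independently with probability fraction, so a batch has about fraction times the split's row count, not an exact size. If the fraction is computed from total (all pairs) while sampling only the 80% training split, the expected training batch is about 80% of batch_size. A sketch of that arithmetic plus a simulated Bernoulli draw; the row counts and function names are hypothetical.

```python
import math
import random


def expected_batch_rows(batch_size: int, fraction_base: int, split_rows: int) -> float:
    """Expected rows of df.sample(fraction=batch_size / fraction_base) on a split of split_rows rows."""
    return batch_size / fraction_base * split_rows


def bernoulli_std(split_rows: int, fraction: float) -> float:
    """Standard deviation of the sampled row count under Bernoulli sampling."""
    return math.sqrt(split_rows * fraction * (1.0 - fraction))


def simulate_sample_size(split_rows: int, fraction: float, seed: int = 0) -> int:
    """Keep each row independently with probability `fraction` and count the survivors."""
    rng = random.Random(seed)
    return sum(rng.random() < fraction for _ in range(split_rows))


# Hypothetical sizes: 10,000 pairs split 8,000 / 2,000, batch_size = 500
total, n_train, n_val = 10_000, 8_000, 2_000
print(expected_batch_rows(500, total, n_train))    # about 400 training rows
print(expected_batch_rows(500, total, n_val))      # about 100 validation rows
print(expected_batch_rows(500, n_train, n_train))  # about 500 when the split's own count is used
print(round(bernoulli_std(n_train, 500 / total), 1), simulate_sample_size(n_train, 500 / total))
```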

Training the model

In [0]:
import ast

# Row count of each split, so that each sampled batch has about batch_size rows
n_rows = {'train': train_data.count(), 'val': test_data.count()}

for batch_size in [500, 200]:
    for lr in [1e-5, 1e-3, 1e-1]:
        for wd in [0, 1e-4, 1e-2]:
            print(f'============= Batch size {batch_size}, Learning rate {lr}, Weight Decay {wd} =============')
            # Reset the weights so that every configuration starts from scratch
            for layer in model.modules():
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
            for i in range(iteration):
                print(f'Iteration {i}', end=' ')
                try:
                    for split, df in [('train', train_data), ('val', test_data)]:
                        # Sample a batch without replacement from the train/validation split
                        batch = df.sample(fraction=min(1.0, batch_size / n_rows[split]))
                        imgs_nm1, imgs_nm2, lab = list(zip(*batch.rdd.map(lambda x: [x[0], x[1], x[2]]).collect()))
                        imgs1 = [conv(x)[0][0] for x in imgs_nm1]
                        imgs2 = [conv(x)[0][0] for x in imgs_nm2]

                        label = torch.tensor(list(map(int, lab)), dtype=torch.float32).reshape(-1, 1).to(device)
                        # Embeddings are stored as strings; parse them safely instead of using eval()
                        tens1 = torch.tensor([ast.literal_eval(val) for val in imgs1], dtype=torch.float32)
                        tens2 = torch.tensor([ast.literal_eval(val) for val in imgs2], dtype=torch.float32)
                        ar = torch.hstack((tens1, tens2)).to(device)

                        if split == 'train':
                            model.train()
                            y_hat = model(ar)
                        else:
                            model.eval()
                            with torch.no_grad():
                                y_hat = model(ar)
                        loss = criterion(y_hat, label)
                        # The model outputs logits: logit > 0 is equivalent to sigmoid(logit) > 0.5
                        y_thresh = (y_hat > 0).float()
                        accuracy = torch.mean((y_thresh == label).float()).item()
                        if split == 'train':
                            train_loss_lst.append(loss.item())
                            train_accuracy_lst.append(accuracy)
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                            print(f'training loss = {loss.item():.4f}, training accuracy = {accuracy:.4f}', end=' ')
                        else:
                            test_loss_lst.append(loss.item())
                            test_accuracy_lst.append(accuracy)
                            print(f'val loss = {loss.item():.4f}, val accuracy = {accuracy:.4f}')
                    # Stop this configuration once the last 4 training accuracies are non-increasing
                    if i > 4 and train_accuracy_lst[-4:] == sorted(train_accuracy_lst[-4:], reverse=True):
                        print('-------------------- Early Stopping Triggered --------------------')
                        break
                except ValueError:
                    # Raised e.g. for an empty batch (nothing to unpack) or a single-row
                    # batch, which BatchNorm1d cannot normalize in training mode
                    print('---------- iteration aborted ----------')
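The early-stopping test in this loop fires when the last four training-batch accuracies are non-increasing. That also triggers on four equal values, and it is driven by noisy single-batch estimates. A sketch of that rule next to a common alternative, patience-based stopping on validation loss; original_rule and the EarlyStopping class are hypothetical helpers, not something the notebook defines.

```python
def original_rule(train_acc, window=4):
    """The loop's rule: stop when the last `window` training accuracies are non-increasing."""
    recent = train_acc[-window:]
    return len(recent) == window and recent == sorted(recent, reverse=True)


class EarlyStopping:
    """Stop when validation loss has not improved by more than min_delta for `patience` checks."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


print(original_rule([0.5, 0.6, 0.7, 0.6, 0.55, 0.5]))  # True
print(original_rule([0.5, 0.5, 0.5, 0.5]))             # True (flat accuracies also stop)
stopper = EarlyStopping(patience=3)
print([stopper.step(v) for v in [1.0, 0.8, 0.85, 0.82, 0.81]])  # [False, False, False, False, True]
```

In the loop, stopper.step(loss.item()) could be called after each validation batch, breaking out of the iteration loop when it returns True; a fresh EarlyStopping would be created per hyperparameter configuration.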




