In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import wandb
import os, sys
import glob
import numpy as np
import torch
import pandas as pd
import random
import torch.nn as nn
import pickle
import torch.nn.functional as F

In [3]:
from mstcn_model import *
from utility.adaptive_data_loader import Breakfast, collate_fn_override
from utils import calculate_mof, dotdict

In [4]:
os.environ["WANDB_API_KEY"] = "992b3b1371ba79f48484cfca522b3786d7fa52c2"
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdipika_singhania[0m (use `wandb login --relogin` to force relogin)


True

In [8]:
seed = 42

# Ensure deterministic behavior
def set_seed():
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
set_seed()

# Device configuration
os.environ['CUDA_VISIBLE_DEVICES']='5'
# os.environ['CUDA_LAUNCH_BLOCKING']='6'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
config = dotdict(
    epochs=500,
    num_class=48,
    batch_size=8,
    learning_rate=5e-4,
    weight_decay=0,
    dataset="Breakfast",
    architecture="unet-ensemble",
    features_file_name="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/",
    chunk_size=1,
    max_frames_per_video=1200,
    feature_size=2048,
    ground_truth_files_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/",
    label_id_csv="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv",
    gamma=0.1,
    step_size=500,
    split=3,
#     output_dir="/mnt/data/ar-datasets/dipika/breakfast/ms_tcn/data/breakfast/results/unsuper-finetune-split2-0.05-data-llr/",
    output_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split3/",
    project_name="breakfast-split-3",
    train_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split{}.bundle",
    test_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split{}.bundle",
    all_files="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt",
    cutoff=8,
    data_per = 0.2,
    budget=40,
    semi_supervised_split="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/semi_supervised/train.split{}_amt{}.bundle")

config.train_split_file = config.train_split_file.format(config.split)
config.semi_supervised_split = config.semi_supervised_split.format(config.split, config.data_per)
config.test_split_file = config.test_split_file.format(config.split)

if not os.path.exists(config.output_dir):
    os.mkdir(config.output_dir)

print(config)

{'epochs': 500, 'num_class': 48, 'batch_size': 8, 'learning_rate': 0.0005, 'weight_decay': 0, 'dataset': 'Breakfast', 'architecture': 'unet-ensemble', 'features_file_name': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/', 'chunk_size': 1, 'max_frames_per_video': 1200, 'feature_size': 2048, 'ground_truth_files_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/', 'label_id_csv': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv', 'gamma': 0.1, 'step_size': 500, 'split': 3, 'output_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split3/', 'project_name': 'breakfast-split-3', 'train_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split3.bundle', 'test_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split3.bundle', 'all_files': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt', 'cutoff': 8, 'data_per': 0.2, 'budget': 40, 'semi_supe

In [10]:
df=pd.read_csv(config.label_id_csv)
label_id_to_label_name = {}
label_name_to_label_id_dict = {}
for i, ele in df.iterrows():
    label_id_to_label_name[ele.label_id] = ele.label_name
    label_name_to_label_id_dict[ele.label_name] = ele.label_id

In [11]:
traindataset = Breakfast(config, fold='train', fold_file_name=config.train_split_file)
testdataset = Breakfast(config, fold='test', fold_file_name=config.test_split_file)

Number of videos logged in train fold is 1279
Number of videos not found in train fold is 0
Number of videos logged in test fold is 433
Number of videos not found in test fold is 0


In [12]:
def _init_fn(worker_id):
    np.random.seed(int(seed))
trainloader = torch.utils.data.DataLoader(dataset=traindataset,
                                          batch_size=config.batch_size, 
                                          shuffle=True,
                                          pin_memory=True, num_workers=4, 
                                          collate_fn=lambda x: collate_fn_override(x, config.max_frames_per_video),
                                          worker_init_fn=_init_fn)
testloader = torch.utils.data.DataLoader(dataset=testdataset,
                                          batch_size=config.batch_size, 
                                          shuffle=False,
                                          pin_memory=True, num_workers=4,
                                          collate_fn=lambda x: collate_fn_override(x, config.max_frames_per_video),
                                          worker_init_fn=_init_fn)

In [13]:
set_seed()
model = MultiStageModel(num_stages=4, num_layers=10, num_f_maps=64, dim=2048, num_classes=48).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# Requires loaded_vidid_selected_frames, boundaries_dict
ce_criterion = nn.CrossEntropyLoss(ignore_index=-100)
mse_criterion = nn.MSELoss(reduction='none')

In [14]:
pseudo_labels_dir = "/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/length_segmentation_output/"
def get_single_random(output_p, video_ids):
    # Generate target for only timestamps. Do not generate pseudo labels at first 30 epochs.
    boundary_target_tensor = torch.ones((output_p.shape[0], output_p.shape[2]), dtype=torch.long, 
                                        device=output_p.device) * (-100)
    for iter_num, cur_vidid in enumerate(video_ids):
        pseudo_l = open(pseudo_labels_dir + cur_vidid + ".txt").read().split("\n")[0:-1]
        pseudo_l = [label_name_to_label_id_dict[ele] for ele in pseudo_l]
        abc = torch.tensor(pseudo_l).to(torch.long).to(boundary_target_tensor.device)
        frame_idx_tensor = torch.arange(0, len(pseudo_l), 1).to(device)
        boundary_target_tensor[iter_num, frame_idx_tensor] = abc

    return boundary_target_tensor

In [15]:
best_val_acc = 0
best_epoch = -1
for epoch in range(100):
    print("Starting Training")
    model.train()
    for i, item in enumerate(trainloader):
        item_0 = item[0].to(device)
        item_1 = item[1].to(device)
        item_2 = item[2].to(device)
        src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
        src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)
        optimizer.zero_grad()
        
        middle_pred, predictions = model(item_0, src_mask_mse)
        psuedo_l = get_single_random(predictions[-1], item[4])
        loss = 0
        for p in predictions:
            loss += ce_criterion(p, psuedo_l)
            loss += 0.15 * torch.mean(torch.clamp(mse_criterion(F.log_softmax(p[:, :, 1:], dim=1), 
                                                                F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0,
                                        max=16) * src_mask_mse[:, :, 1:])
            

        loss.backward()
        optimizer.step()
        
        if i % 10 == 0:
            with torch.no_grad():
                pred = torch.argmax(predictions[-1], dim=1)
                correct = float(torch.sum((pred == item_2) * src_mask).item())
                total = float(torch.sum(src_mask).item())
                print(f"Training:: Epoch {epoch}, Iteration {i}, Current loss {loss.item()}" +
                      f" Accuracy {correct * 100.0 / total}")
    # Calculating Expectation Step
    model.eval()

    print("Calculating Validation Data Accuracy")
    correct = 0.0
    total = 0.0
    for i, item in enumerate(testloader):
        with torch.no_grad():
            item_0 = item[0].to(device)
            item_1 = item[1].to(device)
            item_2 = item[2].to(device)
            src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
            src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)

            middle_pred, predictions = model(item_0, src_mask_mse)

            pred = torch.argmax(predictions[-1], dim=1)
            correct += float(torch.sum((pred == item_2) * src_mask).item())
            total += float(torch.sum(src_mask).item())
    val_acc = correct * 100.0 / total
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        torch.save(model.state_dict(), config.output_dir + "ms-tcn-best-model.wt")
    torch.save(model.state_dict(), config.output_dir + "ms-tcn-last-model.wt")
    print(f"Validation:: Epoch {epoch}, Probability Accuracy {val_acc}")

Starting Training
Training:: Epoch 0, Iteration 0, Current loss 15.790573120117188 Accuracy 4.178496481915813
Training:: Epoch 0, Iteration 10, Current loss 15.370514869689941 Accuracy 3.826644122096645
Training:: Epoch 0, Iteration 20, Current loss 14.129740715026855 Accuracy 1.939795375243953
Training:: Epoch 0, Iteration 30, Current loss 12.482433319091797 Accuracy 11.194074428050893
Training:: Epoch 0, Iteration 40, Current loss 12.917685508728027 Accuracy 12.976514635806671
Training:: Epoch 0, Iteration 50, Current loss 12.395744323730469 Accuracy 16.094906133431174
Training:: Epoch 0, Iteration 60, Current loss 14.231739044189453 Accuracy 11.620149004781497
Training:: Epoch 0, Iteration 70, Current loss 11.462493896484375 Accuracy 8.897485493230175
Training:: Epoch 0, Iteration 80, Current loss 11.71151351928711 Accuracy 10.336245697643632
Training:: Epoch 0, Iteration 90, Current loss 11.299846649169922 Accuracy 9.973498804214337
Training:: Epoch 0, Iteration 100, Current loss 1

Training:: Epoch 5, Iteration 20, Current loss 7.353093147277832 Accuracy 38.530827768287004
Training:: Epoch 5, Iteration 30, Current loss 8.147224426269531 Accuracy 33.62985366721523
Training:: Epoch 5, Iteration 40, Current loss 6.804376602172852 Accuracy 35.63823071419553
Training:: Epoch 5, Iteration 50, Current loss 5.116456031799316 Accuracy 53.302702210899824
Training:: Epoch 5, Iteration 60, Current loss 6.010030269622803 Accuracy 50.816286388670335
Training:: Epoch 5, Iteration 70, Current loss 6.13594913482666 Accuracy 34.466871434839845
Training:: Epoch 5, Iteration 80, Current loss 10.460982322692871 Accuracy 15.978131475684494
Training:: Epoch 5, Iteration 90, Current loss 8.309371948242188 Accuracy 24.25101373961655
Training:: Epoch 5, Iteration 100, Current loss 5.860569000244141 Accuracy 54.52690061794886
Training:: Epoch 5, Iteration 110, Current loss 10.958028793334961 Accuracy 23.074071786001113
Training:: Epoch 5, Iteration 120, Current loss 9.959272384643555 Accur

Training:: Epoch 10, Iteration 50, Current loss 4.701625823974609 Accuracy 59.192935720980714
Training:: Epoch 10, Iteration 60, Current loss 4.9347004890441895 Accuracy 50.88930034225439
Training:: Epoch 10, Iteration 70, Current loss 5.10862398147583 Accuracy 63.822047065064986
Training:: Epoch 10, Iteration 80, Current loss 4.778704643249512 Accuracy 54.361636723819935
Training:: Epoch 10, Iteration 90, Current loss 4.647982597351074 Accuracy 51.502018522916174
Training:: Epoch 10, Iteration 100, Current loss 3.4182682037353516 Accuracy 64.79775730877053
Training:: Epoch 10, Iteration 110, Current loss 4.693260192871094 Accuracy 61.22673368412786
Training:: Epoch 10, Iteration 120, Current loss 4.50994348526001 Accuracy 54.75123263110713
Training:: Epoch 10, Iteration 130, Current loss 4.696784496307373 Accuracy 54.08753096614368
Training:: Epoch 10, Iteration 140, Current loss 5.332844257354736 Accuracy 54.73626574192371
Training:: Epoch 10, Iteration 150, Current loss 4.2814593315

Training:: Epoch 15, Iteration 70, Current loss 4.606410026550293 Accuracy 61.505944517833555
Training:: Epoch 15, Iteration 80, Current loss 4.158119201660156 Accuracy 57.05001969279244
Training:: Epoch 15, Iteration 90, Current loss 4.054626941680908 Accuracy 54.39648151164685
Training:: Epoch 15, Iteration 100, Current loss 5.626355171203613 Accuracy 60.22376983673935
Training:: Epoch 15, Iteration 110, Current loss 5.021288871765137 Accuracy 47.6540043337903
Training:: Epoch 15, Iteration 120, Current loss 4.291757106781006 Accuracy 67.80562436252367
Training:: Epoch 15, Iteration 130, Current loss 3.416718006134033 Accuracy 66.74364896073904
Training:: Epoch 15, Iteration 140, Current loss 3.8848893642425537 Accuracy 66.28718083706238
Training:: Epoch 15, Iteration 150, Current loss 4.288409233093262 Accuracy 58.89736823336424
Calculating Validation Data Accuracy
Validation:: Epoch 15, Probability Accuracy 52.21042293290774
Starting Training
Training:: Epoch 16, Iteration 0, Curre

Training:: Epoch 20, Iteration 90, Current loss 2.972078323364258 Accuracy 65.28531134736386
Training:: Epoch 20, Iteration 100, Current loss 3.507822036743164 Accuracy 58.968756827616346
Training:: Epoch 20, Iteration 110, Current loss 5.561309337615967 Accuracy 49.16553411253058
Training:: Epoch 20, Iteration 120, Current loss 4.10400390625 Accuracy 56.3040595997222
Training:: Epoch 20, Iteration 130, Current loss 3.051084280014038 Accuracy 66.48866737866382
Training:: Epoch 20, Iteration 140, Current loss 3.465397596359253 Accuracy 56.016060095842505
Training:: Epoch 20, Iteration 150, Current loss 3.5183322429656982 Accuracy 70.77714285714286
Calculating Validation Data Accuracy
Validation:: Epoch 20, Probability Accuracy 55.739867523060695
Starting Training
Training:: Epoch 21, Iteration 0, Current loss 2.838994026184082 Accuracy 66.7671391008786
Training:: Epoch 21, Iteration 10, Current loss 3.005891799926758 Accuracy 61.51301900070373
Training:: Epoch 21, Iteration 20, Current 

Training:: Epoch 25, Iteration 110, Current loss 3.3176393508911133 Accuracy 64.81006025674614
Training:: Epoch 25, Iteration 120, Current loss 3.2464802265167236 Accuracy 65.22645847815625
Training:: Epoch 25, Iteration 130, Current loss 3.095716714859009 Accuracy 58.422350041084634
Training:: Epoch 25, Iteration 140, Current loss 3.0975496768951416 Accuracy 55.65167243367935
Training:: Epoch 25, Iteration 150, Current loss 3.310845375061035 Accuracy 73.32132564841498
Calculating Validation Data Accuracy
Validation:: Epoch 25, Probability Accuracy 54.321460259708374
Starting Training
Training:: Epoch 26, Iteration 0, Current loss 2.563094139099121 Accuracy 62.46279445531083
Training:: Epoch 26, Iteration 10, Current loss 6.8735833168029785 Accuracy 65.13337893296854
Training:: Epoch 26, Iteration 20, Current loss 3.7761824131011963 Accuracy 66.07632031598395
Training:: Epoch 26, Iteration 30, Current loss 2.9738712310791016 Accuracy 64.16936488169365
Training:: Epoch 26, Iteration 40,

Training:: Epoch 30, Iteration 120, Current loss 2.6697638034820557 Accuracy 68.34427885829756
Training:: Epoch 30, Iteration 130, Current loss 2.6172778606414795 Accuracy 60.78732463181406
Training:: Epoch 30, Iteration 140, Current loss 2.1673996448516846 Accuracy 60.907571667309426
Training:: Epoch 30, Iteration 150, Current loss 2.7019455432891846 Accuracy 59.94409503843466
Calculating Validation Data Accuracy
Validation:: Epoch 30, Probability Accuracy 53.86565923181232
Starting Training
Training:: Epoch 31, Iteration 0, Current loss 2.1089377403259277 Accuracy 57.627923542680755
Training:: Epoch 31, Iteration 10, Current loss 2.9002270698547363 Accuracy 51.24848668280872
Training:: Epoch 31, Iteration 20, Current loss 1.988592267036438 Accuracy 63.960880195599024
Training:: Epoch 31, Iteration 30, Current loss 2.537104606628418 Accuracy 58.05784194665587
Training:: Epoch 31, Iteration 40, Current loss 1.9774408340454102 Accuracy 57.73097826086956
Training:: Epoch 31, Iteration 50

Training:: Epoch 35, Iteration 130, Current loss 2.9510486125946045 Accuracy 58.249019202973365
Training:: Epoch 35, Iteration 140, Current loss 2.717771053314209 Accuracy 71.09184502844758
Training:: Epoch 35, Iteration 150, Current loss 1.8892467021942139 Accuracy 66.47031467989459
Calculating Validation Data Accuracy
Validation:: Epoch 35, Probability Accuracy 54.00539059491855
Starting Training
Training:: Epoch 36, Iteration 0, Current loss 1.5283524990081787 Accuracy 64.32868783794075
Training:: Epoch 36, Iteration 10, Current loss 1.928717017173767 Accuracy 65.0532684466658


KeyboardInterrupt: 

In [16]:
config.output_dir + "ms-tcn-best-model.wt"

'/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split1/ms-tcn-best-model.wt'

In [15]:
best_epoch

10

In [24]:
torch.save(model.state_dict(),
"/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-speed/final-em-maximized.wt")

In [34]:
model.load_state_dict(torch.load(f"/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-split3/ms-tcn-initial-15-epochs.wt"))

<All keys matched successfully>

In [25]:
print("Calculating Validation Data Accuracy")
correct = 0.0
total = 0.0
for i, item in enumerate(testloader):
    with torch.no_grad():
        item_0 = item[0].to(device)
        item_1 = item[1].to(device)
        item_2 = item[2].to(device)
        src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
        src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)

        middle_pred, predictions = model(item_0, src_mask_mse)

        pred = torch.argmax(predictions[-1], dim=1)
        correct += float(torch.sum((pred == item_2) * src_mask).item())
        total += float(torch.sum(src_mask).item())

print(f"Validation:: Epoch {epoch}, Probability Accuracy {correct * 100.0 / total}")

Calculating Validation Data Accuracy
Validation:: Epoch 15, Probability Accuracy 57.770101729343025


In [9]:
import pickle
# pickle.dump(video_id_boundary_frames, open("dump_dir/video_id_boundary_frames_dict.pkl", "wb"))
# pickle.dump(loaded_vidid_selected_frames, open("dump_dir/loaded_vidid_selected_frames_dict.pkl", "wb"))
pickle.dump(boundary_dict, open("dump_dir/chunk_1_video_id_boundary_frames_dict.pkl", "wb"))