In [3]:
%load_ext autoreload
%autoreload 2

In [21]:
import wandb
import os, sys
import glob
import numpy as np
import torch
import pandas as pd
import random
import torch.nn as nn
import pickle
import torch.nn.functional as F
import matplotlib.pyplot as plt
import multiprocessing as mp
from time import time
from utils import get_all_scores

In [5]:
from mstcn_model import *
from utility.adaptive_data_loader import Breakfast, collate_fn_override
from utility.adaptive_data_loader import BreakfastWithWeights, collate_fn_override_wtd
from utils import calculate_mof, dotdict

In [6]:
os.environ["WANDB_API_KEY"] = "992b3b1371ba79f48484cfca522b3786d7fa52c2"
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdipika_singhania[0m (use `wandb login --relogin` to force relogin)


True

In [7]:
seed = 42

# Ensure deterministic behavior
def set_seed():
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
set_seed()

# Device configuration
os.environ['CUDA_VISIBLE_DEVICES']='7'
# os.environ['CUDA_LAUNCH_BLOCKING']='6'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
config = dotdict(
    epochs=500,
    num_class=48,
    batch_size=8,
    learning_rate=5e-4,
    weight_decay=0,
    dataset="Breakfast",
    architecture="unet-ensemble",
    features_file_name="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/",
    chunk_size=1,
    max_frames_per_video=1200,
    feature_size=2048,
    ground_truth_files_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/",
    label_id_csv="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv",
    gamma=0.1,
    step_size=500,
    split=1,
#     output_dir="/mnt/data/ar-datasets/dipika/breakfast/ms_tcn/data/breakfast/results/unsuper-finetune-split2-0.05-data-llr/",
    output_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-random-select5/",
    project_name="breakfast-split-1",
    train_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split{}.bundle",
    test_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split{}.bundle",
    all_files="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt",
    cutoff=8,
    data_per = 0.2,
    budget=40,
    semi_supervised_split="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/semi_supervised/train.split{}_amt{}.bundle")

config.train_split_file = config.train_split_file.format(config.split)
config.semi_supervised_split = config.semi_supervised_split.format(config.split, config.data_per)
config.test_split_file = config.test_split_file.format(config.split)

if not os.path.exists(config.output_dir):
    os.mkdir(config.output_dir)

config.output_dir = config.output_dir + f"split{config.split}"
if not os.path.exists(config.output_dir):
    os.mkdir(config.output_dir)
config.output_dir = config.output_dir + "/"
if not os.path.exists(os.path.join(config.output_dir, "posterior_weights")):
    os.mkdir(os.path.join(config.output_dir, "posterior_weights"))
print(config)

{'epochs': 500, 'num_class': 48, 'batch_size': 8, 'learning_rate': 0.0005, 'weight_decay': 0, 'dataset': 'Breakfast', 'architecture': 'unet-ensemble', 'features_file_name': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/', 'chunk_size': 1, 'max_frames_per_video': 1200, 'feature_size': 2048, 'ground_truth_files_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/', 'label_id_csv': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv', 'gamma': 0.1, 'step_size': 500, 'split': 1, 'output_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-random-select5/split1/', 'project_name': 'breakfast-split-1', 'train_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split1.bundle', 'test_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split1.bundle', 'all_files': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt', 'cutoff': 8, 'data_per': 0.2, 'budget': 40, 'semi_supervised_split':

In [9]:
traindataset = BreakfastWithWeights(config, fold='train', fold_file_name=config.train_split_file)
testdataset = Breakfast(config, fold='test', fold_file_name=config.test_split_file)

Number of videos logged in train fold is 1460
Number of videos not found in train fold is 0
Number of videos logged in test fold is 252
Number of videos not found in test fold is 0


In [10]:
def _init_fn(worker_id):
    np.random.seed(int(seed))
trainloader = torch.utils.data.DataLoader(dataset=traindataset,
                                          batch_size=config.batch_size, 
                                          shuffle=True,
                                          pin_memory=True, num_workers=4, 
                                          collate_fn=lambda x: collate_fn_override_wtd(x, config.max_frames_per_video),
                                          worker_init_fn=_init_fn)
testloader = torch.utils.data.DataLoader(dataset=testdataset,
                                          batch_size=config.batch_size, 
                                          shuffle=False,
                                          pin_memory=True, num_workers=4,
                                          collate_fn=lambda x: collate_fn_override(x, config.max_frames_per_video),
                                          worker_init_fn=_init_fn)

trainloder_expectation = torch.utils.data.DataLoader(dataset=traindataset,
                                          batch_size=20,
                                          shuffle=True,
                                          pin_memory=True, num_workers=4, 
                                          collate_fn=lambda x: collate_fn_override_wtd(x, config.max_frames_per_video),
                                          worker_init_fn=_init_fn)

In [11]:
df = pd.read_csv(config.label_id_csv)
label_id_to_label_name = {}
label_name_to_label_id_dict = {}
for i, ele in df.iterrows():
    label_id_to_label_name[ele.label_id] = ele.label_name
    label_name_to_label_id_dict[ele.label_name] = ele.label_id

In [12]:
# selected_frames_dict = pickle.load(open("data/breakfast_len_assum_annotations.pkl", 'rb'))
# loaded_vidid_selected_frames
boundary_frames_dict = pickle.load(open("data/breakfast_boundary_annotations.pkl", "rb"))
num_boundary = 0
for key in boundary_frames_dict.keys():
    num_boundary += len(boundary_frames_dict[key])
# video_id_boundary_frames

In [13]:
selected_frames_dict = pickle.load(open("data/breakfast_random5frame_selection.pkl", "rb"))
# print(selected_frames_dict)

In [14]:
loaded_mean_var_actions = pickle.load(open("data/breakfast_meanvar_actions.pkl", "rb"))
mat_poisson = pickle.load(open("data/breakfast_possion_class_dict.pkl", "rb"))

def get_possion_prob(minlen, maxlen, cur_class):
    prob = mat_poisson[label_id_to_label_name[cur_class]][minlen:maxlen]
    return torch.tensor(prob)

def get_poisson_logcdf(minlen, cur_class):
    return np.log(np.sum(np.exp(mat_poisson[label_id_to_label_name[cur_class]][minlen:])) + 1e-20)

def get_possion_prob_for_all_class(minlen, maxlen):
    ele_list = []
    for i in range(config.num_class):
        prob = mat_poisson[label_id_to_label_name[i]][minlen:maxlen]
        ele_list.append(torch.tensor(prob))
    return torch.stack(ele_list, dim=-1)

In [22]:
def validate(model, dataloader, best_val_acc=None):
    model.eval()
    print("Calculating Validation Data Accuracy")
    correct = 0.0
    total = 0.0
    vidcount = 0
    all_scores = []
    for i, item in enumerate(testloader):
        with torch.no_grad():
            item_0 = item[0].to(device)
            item_1 = item[1].to(device)
            item_2 = item[2].to(device)
            src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
            src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)
            middle_pred, predictions = model(item_0, src_mask_mse)
            pred = torch.argmax(predictions[-1], dim=1)
            correct += float(torch.sum((pred == item_2) * src_mask).item())
            total += float(torch.sum(src_mask).item())
            for p, l, c in zip(pred, item_2, item_1):
                all_scores.append(get_all_scores(p[:c].detach().cpu().numpy(), 
                                                 l[:c].detach().cpu().numpy(), ['SIL']))
            
    final_scores = np.mean(np.array(all_scores), axis=0)
    val_acc = correct * 100.0 / total
    if best_val_acc is not None and val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), config.output_dir + "ms-tcn-emmax-best-model.wt")
    torch.save(model.state_dict(), config.output_dir + "ms-tcn-emmax-last-model.wt")
    print(f"Validation:: Probability Accuracy {val_acc}")
    print(f"Other scores:: Edit {final_scores[3]}, F1@[10:25:50] {final_scores[:3]}")
    _ = model.train()
    return val_acc, best_val_acc

In [14]:
def prob_vals_per_segment(selected_frames, cur_vid_feat, labels, first_ele_flag, last_ele_flag, vidid, gt_labels):
    prob_each_segment = []
    LOW_VAL = -10000000
    num_frames = len(cur_vid_feat)
    log_probs = torch.log(cur_vid_feat + 1e-8)
    cumsum_feat = torch.cumsum(log_probs, dim=0)
    prev_boundary = 0
    per_frame_weights = torch.zeros((num_frames, config.num_class))
    start_time = time()
    boundary_error = 0
    current_boundary = 0
    labels = [config.num_class-1] + labels if selected_frames[0] != 0 else labels
    labels = labels + [config.num_class-1] if selected_frames[-1] != num_frames-1 else labels
    selected_frames = [0] + selected_frames if selected_frames[0] != 0 else selected_frames
    selected_frames = selected_frames + [num_frames-1] if selected_frames[-1] != num_frames-1 else selected_frames

    for i, cur_ele in enumerate(selected_frames[:-1]):
        next_ele = selected_frames[i + 1]
        label_cur_ele = labels[i]
        label_next_ele = labels[i + 1]
        if cur_ele == next_ele-1:
            per_frame_weights[cur_ele, label_cur_ele] = 1.0
            if label_cur_ele != label_next_ele:
                prev_boundary = cur_ele
            continue
        
        seg_len = next_ele - cur_ele
        mat_b1_b2_c_prob = LOW_VAL * torch.ones((seg_len, seg_len, config.num_class), dtype=cumsum_feat.dtype)
        b1_prior = get_possion_prob(cur_ele-prev_boundary, next_ele-prev_boundary, label_cur_ele)
        
        # find dummy label where we will keep the diagonal (b1=b2) probabilities, later we will distribute among
        # rest of the classes after the softmax by dividing by (num_class - 2)
        dummy_label = 0
        while True:
            if dummy_label != label_cur_ele and dummy_label != label_next_ele:
                break
            else:
                dummy_label += 1
        
        for b1 in range(cur_ele, next_ele - 1):

            cur_boundary_len = b1 - prev_boundary
            strt_index = cumsum_feat[cur_ele - 1, label_cur_ele] if cur_ele > 0 else 0
            left_sum = (cumsum_feat[b1, label_cur_ele] - strt_index)
            right_sum = cumsum_feat[next_ele-1, label_next_ele] - cumsum_feat[b1+1:next_ele, label_next_ele] # mid_seg_len
            mid_sum = (cumsum_feat[b1+1:next_ele, :] - cumsum_feat[b1, :])  # mid_seg_len
            b2_prior = get_possion_prob_for_all_class(1, next_ele-b1)  # mid_seg_len x num_class
            
            mat_b1_b2_c_prob[b1-cur_ele, b1+1-cur_ele:next_ele-cur_ele] = (left_sum + right_sum[:,None] + mid_sum) \
                                                                            + b1_prior[b1-cur_ele] + b2_prior
            # when mid segment is absent but right and left is not the same
            # we assign the probability to a dummy label for now and then later 
            # re-distribute among other classes after the softmax
            if label_cur_ele != label_next_ele:
                rightsum_wo_midseg = cumsum_feat[next_ele-1, label_next_ele] - cumsum_feat[b1, label_next_ele]
                mat_b1_b2_c_prob[b1-cur_ele, b1-cur_ele, dummy_label] = left_sum + rightsum_wo_midseg + b1_prior[b1-cur_ele]
        
#         if vidid=='P39_cam02_P39_scrambledegg' and cur_ele==574:
#             import pdb
#             pdb.set_trace()
        # when mid segment is absent b1 can also be next_ele-1
        b1 = next_ele - 1
        if label_cur_ele != label_next_ele:
            left_sum = (cumsum_feat[b1, label_cur_ele] - strt_index)
            mat_b1_b2_c_prob[b1-cur_ele, b1-cur_ele, dummy_label] = left_sum + b1_prior[b1-cur_ele]
        else:
            # returns prob that the left class length >= seg len
            b1_prior_ = get_poisson_logcdf(next_ele - prev_boundary, label_cur_ele) 
            mat_b1_b2_c_prob[b1-cur_ele, b1-cur_ele, dummy_label] = left_sum + b1_prior_
        
        mat_b1_b2_c_prob[:, :, label_cur_ele] = LOW_VAL
        mat_b1_b2_c_prob[:, :, label_next_ele] = LOW_VAL
        mat_b1_b2_c_prob = torch.softmax(mat_b1_b2_c_prob.flatten(), dim=0).reshape((seg_len, seg_len, config.num_class))
        
        # re-distribute the dummy class probability among the left-over classes
        left_over_classes = config.num_class - 2 + (label_cur_ele==label_next_ele)
        for b1 in range(cur_ele, next_ele):
            assigned_prob = mat_b1_b2_c_prob[b1-cur_ele, b1-cur_ele, dummy_label]
            mat_b1_b2_c_prob[b1-cur_ele, b1-cur_ele, :] = assigned_prob/left_over_classes
            mat_b1_b2_c_prob[b1-cur_ele, b1-cur_ele, label_cur_ele] = 0
            mat_b1_b2_c_prob[b1-cur_ele, b1-cur_ele, label_next_ele] = 0
        
        marginal_b1 = torch.sum(mat_b1_b2_c_prob, axis=(1,2))
        mean_b1 = round(torch.sum(marginal_b1.squeeze() * torch.arange(cur_ele, next_ele, 1)).item())
        cumm_b1_prob = torch.cumsum(marginal_b1, dim=0)
        cumm_b1_c_prob = torch.cumsum(torch.sum(mat_b1_b2_c_prob, dim=1), dim=0)
        cumm_b2_c_prob = torch.cumsum(torch.sum(mat_b1_b2_c_prob, dim=0), dim=0)

        per_frame_weights[cur_ele, label_cur_ele] = 1.0
        per_frame_weights[cur_ele+1:next_ele, :] = cumm_b1_c_prob[:-1] - cumm_b2_c_prob[:-1]
        per_frame_weights[cur_ele+1:next_ele, label_cur_ele] = 1 - cumm_b1_prob[:-1]
        per_frame_weights[cur_ele+1:next_ele, label_next_ele] = 0
        remaining_probability = 1 - torch.sum(per_frame_weights[cur_ele+1:next_ele, :], dim=-1)
        # we use "+=" in the next line because left and right label might be the same
        # in that case using "=" would just overwrite the previous probability
        per_frame_weights[cur_ele+1:next_ele, label_next_ele] += remaining_probability
        
        expected_boundary = round(torch.sum(torch.sum(mat_b1_b2_c_prob, axis=(0,2)).squeeze() * \
                            torch.arange(cur_ele, next_ele, 1)).item())
        if not (label_cur_ele == label_next_ele and expected_boundary >= next_ele-2):
            prev_boundary = expected_boundary
        if expected_boundary == 0 and i > 0:
            print(f'Estimated boundary has become zero! for {vidid} and cur_ele, next_ele {cur_ele, next_ele}')
            import pdb
            pdb.set_trace()
        # boundary_error += (boundary_frames_dict[vidid + '.txt'][current_boundary] - mean_b1)**2
        # boundary_error += (boundary_frames_dict[vidid + '.txt'][current_boundary+1] - prev_boundary)**2
        # current_boundary += 2
        # prob_each_segment.append(mat_b1_b2_c_prob)
        
    posterior_prediction = torch.argmax(per_frame_weights, dim=1)
    correct = torch.sum(posterior_prediction == gt_labels[:num_frames]).item()
    
    return (vidid, per_frame_weights, [correct, num_frames, boundary_error]) #, prob_each_segment)

In [15]:
posterior_acc_correct, posterior_acc_total = 0, 0
posterior_boundary_total_mse = 0
results = []

# Step 2: Define callback function to collect the output in `results`
def collect_result(result):
    global posterior_acc_correct, posterior_acc_total, posterior_boundary_total_mse
    fname = os.path.join(config.output_dir, 'posterior_weights', result[0] + '.wt')
    torch.save(result[1], fname)
    correct, total, boundary_err = result[2]
    posterior_acc_correct += correct
    posterior_acc_total += total
    posterior_boundary_total_mse += boundary_err
    # print(f'Dumped in file {fname} at time {time()}')
    return

def calculate_element_probb(data_feat, data_count, video_ids, gt_labels): # loaded_vidid_selected_frames, boundaries_dict):
    global posterior_acc_correct, posterior_acc_total, posterior_boundary_total_mse
    pool = mp.Pool(20)
    for iter_num in range(len(data_count)):
        cur_vidid = video_ids[iter_num]
#         if cur_vidid!='P39_cam02_P39_scrambledegg':
#             continue
        cur_vid_count = data_count[iter_num]
        cur_vid_feat = data_feat[iter_num][:cur_vid_count].detach().cpu()
        cur_gt_labels = gt_labels[iter_num].detach().cpu()
        
        cur_video_select_frames = selected_frames_dict[cur_vidid + ".txt"]
        selected_frames_indices_and_labels = cur_video_select_frames
        selected_frames_indices = [ele[0] for ele in selected_frames_indices_and_labels]
        selected_frames_labels = [label_name_to_label_id_dict[ele[1]] for ele in selected_frames_indices_and_labels]
        with torch.no_grad():
            # Multi-processing
            pool.apply_async(prob_vals_per_segment,
                             args=(selected_frames_indices, cur_vid_feat, selected_frames_labels,
                                   cur_video_select_frames[1], cur_video_select_frames[2], cur_vidid, cur_gt_labels),
                             callback=collect_result)
#             results.append(prob_vals_per_segment(selected_frames_indices, cur_vid_feat, selected_frames_labels,
#                                    cur_video_select_frames[1], cur_video_select_frames[2], cur_vidid, cur_gt_labels))
    # Step 4: Close Pool and let all the processes complete
    pool.close()
    pool.join()  # postpones the execution of next line of code until all processes in the queue are done.
    return results

def perform_expectation(model, dataloader):
    global posterior_acc_correct, posterior_acc_total, posterior_boundary_total_mse
    posterior_acc_correct, posterior_acc_total, posterior_boundary_total_mse = 0, 0, 0
    model.eval()
    correct = 0.0
    total = 0.0
    curtime = time()
    print(f'Calculating expectation')

    for i, item in enumerate(dataloader):
        with torch.no_grad():
            item_0 = item[0].to(device) # features
            item_1 = item[1].to(device) # count
            item_2 = item[2].to(device) # gt frame-wise labels
            item_4 = item[4] # video-ids
            src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
            src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)
            middle_pred, predictions = model(item_0, src_mask_mse)
            prob = torch.softmax(predictions[-1], dim=1)
            prob = prob.permute(0, 2, 1)
            
            calculate_element_probb(prob, item_1, item_4, item_2)
            if (i+1) % 10 == 0:
                print(f"iter {i+1} of Expectation completed in a total of {(time() - curtime)/60.: .1f} minutes")
    _ = model.train()
    print(f'Expectation step finished, '
          f'posterior frame-wise accuracy {100*posterior_acc_correct/posterior_acc_total: .2f}%, '
          f'boundary mse {(posterior_boundary_total_mse/num_boundary)**0.5: .2f}')
    return

In [17]:
set_seed()
model = MultiStageModel(num_stages=4, num_layers=10, num_f_maps=64, dim=2048, num_classes=48).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# Requires loaded_vidid_selected_frames, boundaries_dict
ce_criterion = nn.CrossEntropyLoss(ignore_index=-100)
mse_criterion = nn.MSELoss(reduction='none')

In [24]:
# loaded_file=torch.load(os.path.join(config.output_dir, "ms-tcn-initial-30-epochs.wt"))
# model.load_state_dict(loaded_file)
# # loaded_file=torch.load('/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcnnew-full-supervised-split1/ms-tcn-best-model.wt')
# # model.load_state_dict(loaded_file)
# loaded_file=torch.load(os.path.join(config.output_dir, "ms-tcn-emmax-best-model.wt"))
loaded_file=torch.load("/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-split1/ms-tcn-emmax-last-model.wt")
model.load_state_dict(loaded_file)
_ = validate(model, testloader)

Calculating Validation Data Accuracy
Validation:: Probability Accuracy 68.48277281163068
Other scores:: Edit 68.65780790462246, F1@[10:25:50] [68.49952913 64.92506301 53.48251224]


In [18]:
# item = next(iter(trainloader))
    
# with torch.no_grad():
#     item_0 = item[0].to(device) # features
#     item_1 = item[1].to(device) # count
#     item_2 = item[2].to(device) # gt frame-wise labels
#     item_4 = item[4] # video-ids
#     src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
#     src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)
#     middle_pred, predictions = model(item_0, src_mask_mse)
#     prob = torch.softmax(predictions[-1], dim=1)
#     prob = prob.permute(0, 2, 1)

#     res = calculate_element_probb(prob, item_1, item_4, item_2)

In [19]:
# idx = 2
# vidid = res[idx][0]
# mat = res[idx][1]
# mat.shape

In [20]:
# np.linspace(0, 5281, 4 + 1).astype(int)

In [21]:
# boundary_frames_dict[f'{vidid}.txt'], selected_frames_dict[f'{vidid}.txt'], weakly_labels[f'{vidid}.txt']

In [22]:
# fig = plt.figure(figsize=(20, 5))
# for i in range(48):
#     plt.plot(mat[:,i])
    
# for bd in boundary_frames_dict[f'{vidid}.txt']:
#     plt.plot([bd, bd], [0, 2])
    
# for bd in selected_frames_dict[f'{vidid}.txt']:
#     plt.plot([bd[0], bd[0]], [0, 2], '--')

In [23]:
# bd

In [24]:
# Calculating Expectation Step
# perform_expectation(model, trainloder_expectation)

In [25]:
def get_single_random(video_ids, len_frames, device):
    # Generate target for only timestamps. Do not generate pseudo labels at first 30 epochs.
    boundary_target_tensor = torch.ones((len(video_ids), len_frames), dtype=torch.long, device=device) * (-100)
    for iter_num, cur_vidid in enumerate(video_ids):
        selected_frames_indices_and_labels = selected_frames_dict[cur_vidid + ".txt"]
        selected_frames_indices = [ele[0] for ele in selected_frames_indices_and_labels]
        selected_frames_labels = [label_name_to_label_id_dict[ele[1]] for ele in selected_frames_indices_and_labels]

        frame_idx_tensor = torch.from_numpy(np.array(selected_frames_indices))
        frame_labels = torch.from_numpy(np.array(selected_frames_labels)).to(device)
        boundary_target_tensor[iter_num, frame_idx_tensor] = frame_labels

    return boundary_target_tensor

In [26]:
weakly_labels = pickle.load(open("data/breakfast_weaklysupervised_labels.pkl", "rb"))
prior_probs = pickle.load(open('data/breakfast_lengthmodel_multinomial_prior.pkl', 'rb'))

In [27]:
initialize_epoch = 30
expectation_cal_gap = 5
best_val_acc = 0
for epoch in range(0, 150):
    print("Starting Training")
    model.train()
    for i, item in enumerate(trainloader):
        item_0 = item[0].to(device)  # features
        item_1 = item[1].to(device)  # count
        item_2 = item[2].to(device)  # target
        weights = item[5].to(device)  # posterior weight
        src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
        src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)
        optimizer.zero_grad()
        
        middle_pred, predictions = model(item_0, src_mask_mse)
        boundary_target_tensor = get_single_random(item[4], item_2.shape[1], item_2.device)
        
        loss = 0
        for p in predictions:
            if epoch <= initialize_epoch:
                loss += ce_criterion(p, boundary_target_tensor)
                loss += 0.15 * torch.mean(torch.clamp(mse_criterion(F.log_softmax(p[:, :, 1:], dim=1), 
                                                                    F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0,
                                            max=16) * src_mask_mse[:, :, 1:])
            else:
                prob = torch.softmax(p, dim=1)
                prob = prob.permute(0, 2, 1)
                total_count = torch.sum(src_mask)
                weighted_loss_sum = -torch.sum(torch.sum(torch.log(prob + 1e-8) * weights, dim=-1) * src_mask)
                loss += weighted_loss_sum/total_count

        loss.backward()
        optimizer.step()
        if (i+1)%20 == 0:
            print(f'Epoch {epoch+1}: Iteration {i+1} with loss {loss.item()}')

    if (epoch >= initialize_epoch) and ((epoch % (3 * expectation_cal_gap)) == 0):
        torch.save(model.state_dict(), config.output_dir + f"ms-tcn-initial-{epoch}-epochs.wt")

    if epoch >= initialize_epoch and (epoch % expectation_cal_gap == 0):
        perform_expectation(model, trainloder_expectation)
    
    print(f'Epoch {epoch+1} finished, starting validation')
    val_acc, best_val_acc = validate(model, testloader, best_val_acc)


Starting Training
Epoch 1: Iteration 20 with loss 14.473981857299805
Epoch 1: Iteration 40 with loss 14.196702003479004
Epoch 1: Iteration 60 with loss 13.767776489257812
Epoch 1: Iteration 80 with loss 12.938998222351074
Epoch 1: Iteration 100 with loss 12.450570106506348
Epoch 1: Iteration 120 with loss 10.387760162353516
Epoch 1: Iteration 140 with loss 10.396161079406738
Epoch 1: Iteration 160 with loss 9.988459587097168
Epoch 1: Iteration 180 with loss 10.840205192565918
Epoch 1 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 22.554815579852082
Starting Training
Epoch 2: Iteration 20 with loss 9.666767120361328
Epoch 2: Iteration 40 with loss 9.755131721496582
Epoch 2: Iteration 60 with loss 9.738783836364746
Epoch 2: Iteration 80 with loss 9.555734634399414
Epoch 2: Iteration 100 with loss 8.562029838562012
Epoch 2: Iteration 120 with loss 9.58092212677002
Epoch 2: Iteration 140 with loss 7.183716297149658
Epoch 2: Iteration 16

Epoch 14: Iteration 140 with loss 2.839608907699585
Epoch 14: Iteration 160 with loss 3.190566301345825
Epoch 14: Iteration 180 with loss 2.9375979900360107
Epoch 14 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 59.25780832650735
Starting Training
Epoch 15: Iteration 20 with loss 3.808217763900757
Epoch 15: Iteration 40 with loss 2.941004991531372
Epoch 15: Iteration 60 with loss 2.2290096282958984
Epoch 15: Iteration 80 with loss 2.2440226078033447
Epoch 15: Iteration 100 with loss 3.1044468879699707
Epoch 15: Iteration 120 with loss 1.3804901838302612
Epoch 15: Iteration 140 with loss 2.4851391315460205
Epoch 15: Iteration 160 with loss 2.0349082946777344
Epoch 15: Iteration 180 with loss 1.2924021482467651
Epoch 15 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 62.75646885177139
Starting Training
Epoch 16: Iteration 20 with loss 2.2586536407470703
Epoch 16: Iteration 40 with 

Validation:: Probability Accuracy 61.1372279006454
Starting Training
Epoch 28: Iteration 20 with loss 1.5904126167297363
Epoch 28: Iteration 40 with loss 1.4188443422317505
Epoch 28: Iteration 60 with loss 1.4746094942092896
Epoch 28: Iteration 80 with loss 0.9514474868774414
Epoch 28: Iteration 100 with loss 1.2625845670700073
Epoch 28: Iteration 120 with loss 0.8328320980072021
Epoch 28: Iteration 140 with loss 0.6400932669639587
Epoch 28: Iteration 160 with loss 0.9987165331840515
Epoch 28: Iteration 180 with loss 1.1121193170547485
Epoch 28 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 61.94229772348651
Starting Training
Epoch 29: Iteration 20 with loss 0.7593861222267151
Epoch 29: Iteration 40 with loss 1.008871078491211
Epoch 29: Iteration 60 with loss 0.8845118284225464
Epoch 29: Iteration 80 with loss 0.9964534044265747
Epoch 29: Iteration 100 with loss 0.5189375281333923
Epoch 29: Iteration 120 with loss 1.5196924209594727

Epoch 39: Iteration 140 with loss 1.5075066089630127
Epoch 39: Iteration 160 with loss 0.7276022434234619
Epoch 39: Iteration 180 with loss 1.1569093465805054
Epoch 39 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 65.28425751154481
Starting Training
Epoch 40: Iteration 20 with loss 0.966759204864502
Epoch 40: Iteration 40 with loss 1.0238419771194458
Epoch 40: Iteration 60 with loss 0.5241172313690186
Epoch 40: Iteration 80 with loss 1.1334779262542725
Epoch 40: Iteration 100 with loss 1.9233431816101074
Epoch 40: Iteration 120 with loss 2.5904903411865234
Epoch 40: Iteration 140 with loss 2.063493490219116
Epoch 40: Iteration 160 with loss 1.2196705341339111
Epoch 40: Iteration 180 with loss 1.3233022689819336
Epoch 40 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 64.53241054010311
Starting Training
Epoch 41: Iteration 20 with loss 0.8040460348129272
Epoch 41: Iteration 40 wit

Epoch 51: Iteration 40 with loss 1.4721462726593018
Epoch 51: Iteration 60 with loss 0.5729708075523376
Epoch 51: Iteration 80 with loss 0.8798316717147827
Epoch 51: Iteration 100 with loss 0.566657543182373
Epoch 51: Iteration 120 with loss 0.5856693983078003
Epoch 51: Iteration 140 with loss 0.5163483619689941
Epoch 51: Iteration 160 with loss 0.6261053681373596
Epoch 51: Iteration 180 with loss 1.5823124647140503
Calculating expectation
iter 10 of Expectation completed in a total of  3.4 minutes
iter 20 of Expectation completed in a total of  6.6 minutes
iter 30 of Expectation completed in a total of  9.8 minutes
iter 40 of Expectation completed in a total of  12.8 minutes
iter 50 of Expectation completed in a total of  15.8 minutes
iter 60 of Expectation completed in a total of  19.2 minutes
iter 70 of Expectation completed in a total of  22.7 minutes
Expectation step finished, posterior frame-wise accuracy  77.56%, boundary mse  0.00
Epoch 51 finished, starting validation
Calculat

Validation:: Probability Accuracy 63.17908599150809
Starting Training
Epoch 62: Iteration 20 with loss 0.4145604074001312
Epoch 62: Iteration 40 with loss 0.5906035304069519
Epoch 62: Iteration 60 with loss 0.46675607562065125
Epoch 62: Iteration 80 with loss 0.8086108565330505
Epoch 62: Iteration 100 with loss 0.5134670734405518
Epoch 62: Iteration 120 with loss 0.6713782548904419
Epoch 62: Iteration 140 with loss 0.557614266872406
Epoch 62: Iteration 160 with loss 0.5568332672119141
Epoch 62: Iteration 180 with loss 1.9672114849090576
Epoch 62 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 63.41532422411371
Starting Training
Epoch 63: Iteration 20 with loss 1.3137223720550537
Epoch 63: Iteration 40 with loss 1.452265977859497
Epoch 63: Iteration 60 with loss 1.093142032623291
Epoch 63: Iteration 80 with loss 0.7394229769706726
Epoch 63: Iteration 100 with loss 0.8561142086982727
Epoch 63: Iteration 120 with loss 0.8012918829917908

Epoch 73: Iteration 120 with loss 0.3089704215526581
Epoch 73: Iteration 140 with loss 0.3530327081680298
Epoch 73: Iteration 160 with loss 0.5347459316253662
Epoch 73: Iteration 180 with loss 0.21837174892425537
Epoch 73 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 63.068485344919694
Starting Training
Epoch 74: Iteration 20 with loss 1.1085840463638306
Epoch 74: Iteration 40 with loss 0.3893837630748749
Epoch 74: Iteration 60 with loss 0.43246352672576904
Epoch 74: Iteration 80 with loss 0.6669690012931824
Epoch 74: Iteration 100 with loss 0.3250642716884613
Epoch 74: Iteration 120 with loss 0.4407763183116913
Epoch 74: Iteration 140 with loss 0.5543665289878845
Epoch 74: Iteration 160 with loss 0.28781282901763916
Epoch 74: Iteration 180 with loss 0.29603180289268494
Epoch 74 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 62.971536656496944
Starting Training
Epoch 75: Iterati

Validation:: Probability Accuracy 62.135205828001155
Starting Training
Epoch 85: Iteration 20 with loss 0.2286156564950943
Epoch 85: Iteration 40 with loss 0.2455473244190216
Epoch 85: Iteration 60 with loss 0.13663890957832336
Epoch 85: Iteration 80 with loss 0.24196581542491913
Epoch 85: Iteration 100 with loss 0.2452765703201294
Epoch 85: Iteration 120 with loss 0.30068597197532654
Epoch 85: Iteration 140 with loss 0.802282989025116
Epoch 85: Iteration 160 with loss 2.0211524963378906
Epoch 85: Iteration 180 with loss 0.9227396845817566
Epoch 85 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 61.220920339834834
Starting Training
Epoch 86: Iteration 20 with loss 0.8459664583206177
Epoch 86: Iteration 40 with loss 2.2170650959014893
Epoch 86: Iteration 60 with loss 0.6469560265541077
Epoch 86: Iteration 80 with loss 0.6061407923698425
Epoch 86: Iteration 100 with loss 0.759093701839447
Epoch 86: Iteration 120 with loss 0.46105247735

Epoch 96: Iteration 120 with loss 0.38040298223495483
Epoch 96: Iteration 140 with loss 0.4900578260421753
Epoch 96: Iteration 160 with loss 0.4218160808086395
Epoch 96: Iteration 180 with loss 0.5010716915130615
Calculating expectation
iter 10 of Expectation completed in a total of  3.7 minutes
iter 20 of Expectation completed in a total of  6.9 minutes
iter 30 of Expectation completed in a total of  10.5 minutes
iter 40 of Expectation completed in a total of  13.6 minutes
iter 50 of Expectation completed in a total of  16.6 minutes
iter 60 of Expectation completed in a total of  19.9 minutes
iter 70 of Expectation completed in a total of  23.1 minutes
Expectation step finished, posterior frame-wise accuracy  73.96%, boundary mse  0.00
Epoch 96 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 59.307469797515736
Starting Training
Epoch 97: Iteration 20 with loss 0.3176898956298828
Epoch 97: Iteration 40 with loss 0.7296744585037231
Ep

Epoch 107: Iteration 20 with loss 0.8024297952651978
Epoch 107: Iteration 40 with loss 1.6126179695129395
Epoch 107: Iteration 60 with loss 0.7645372748374939
Epoch 107: Iteration 80 with loss 0.6612119078636169
Epoch 107: Iteration 100 with loss 0.33929723501205444
Epoch 107: Iteration 120 with loss 1.428182601928711
Epoch 107: Iteration 140 with loss 0.9542698860168457
Epoch 107: Iteration 160 with loss 0.8313663005828857
Epoch 107: Iteration 180 with loss 1.0668387413024902
Epoch 107 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 57.94642892474012
Starting Training
Epoch 108: Iteration 20 with loss 6.306822776794434
Epoch 108: Iteration 40 with loss 0.8399178981781006
Epoch 108: Iteration 60 with loss 0.6756889224052429
Epoch 108: Iteration 80 with loss 0.45695963501930237
Epoch 108: Iteration 100 with loss 0.7192859649658203
Epoch 108: Iteration 120 with loss 0.4664561152458191
Epoch 108: Iteration 140 with loss 0.74625813961029

Epoch 118: Iteration 100 with loss 0.3372148275375366
Epoch 118: Iteration 120 with loss 0.33542126417160034
Epoch 118: Iteration 140 with loss 0.1804150938987732
Epoch 118: Iteration 160 with loss 0.2850167453289032
Epoch 118: Iteration 180 with loss 0.45846375823020935
Epoch 118 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 59.72256846753802
Starting Training
Epoch 119: Iteration 20 with loss 0.3154414892196655
Epoch 119: Iteration 40 with loss 0.5499111413955688
Epoch 119: Iteration 60 with loss 0.18084734678268433
Epoch 119: Iteration 80 with loss 0.28059494495391846
Epoch 119: Iteration 100 with loss 0.2876918911933899
Epoch 119: Iteration 120 with loss 0.17696411907672882
Epoch 119: Iteration 140 with loss 0.2986498177051544
Epoch 119: Iteration 160 with loss 0.4450201988220215
Epoch 119: Iteration 180 with loss 0.41744956374168396
Epoch 119 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Proba

Epoch 129: Iteration 180 with loss 0.1836945116519928
Epoch 129 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 59.53005607195571
Starting Training
Epoch 130: Iteration 20 with loss 0.16347678005695343
Epoch 130: Iteration 40 with loss 0.16375412046909332
Epoch 130: Iteration 60 with loss 0.123257577419281
Epoch 130: Iteration 80 with loss 0.1983143389225006
Epoch 130: Iteration 100 with loss 0.19463422894477844
Epoch 130: Iteration 120 with loss 0.1917058527469635
Epoch 130: Iteration 140 with loss 0.16564421355724335
Epoch 130: Iteration 160 with loss 0.18553827702999115
Epoch 130: Iteration 180 with loss 0.2273888885974884
Epoch 130 finished, starting validation
Calculating Validation Data Accuracy
Validation:: Probability Accuracy 59.71702854248529
Starting Training
Epoch 131: Iteration 20 with loss 0.13815267384052277
Epoch 131: Iteration 40 with loss 0.16333693265914917
Epoch 131: Iteration 60 with loss 0.2679365873336792
Epoch

Epoch 141: Iteration 20 with loss 0.1790236234664917
Epoch 141: Iteration 40 with loss 0.2252202033996582
Epoch 141: Iteration 60 with loss 0.16876894235610962
Epoch 141: Iteration 80 with loss 0.1765454262495041
Epoch 141: Iteration 100 with loss 0.16316133737564087
Epoch 141: Iteration 120 with loss 0.14833563566207886
Epoch 141: Iteration 140 with loss 0.15401741862297058
Epoch 141: Iteration 160 with loss 0.2950907349586487
Epoch 141: Iteration 180 with loss 0.12430804967880249
Calculating expectation
iter 10 of Expectation completed in a total of  3.3 minutes
iter 20 of Expectation completed in a total of  6.3 minutes
iter 40 of Expectation completed in a total of  11.4 minutes


Process ForkPoolWorker-34195:
Process ForkPoolWorker-34188:
Process ForkPoolWorker-34185:
Process ForkPoolWorker-34181:
Process ForkPoolWorker-34193:
Process ForkPoolWorker-34178:
Process ForkPoolWorker-34179:
Process ForkPoolWorker-34189:
Process ForkPoolWorker-34182:
Process ForkPoolWorker-34177:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process ForkPoolWorker-34187:
Process ForkPoolWorker-34194:
Process ForkPoolWorker-34186:
Process ForkPoolWorker-34184:
Process ForkPoolWorker-34190:
Process ForkPoolWorker-34191:
Process ForkPoolWorker-34196:
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process ForkPoolWorker-34180:
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/mult

  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
KeyboardInterrupt
KeyboardInterrupt
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/queues.py", line 351, in get
    with self._rlock:
KeyboardInterrupt
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 99, in run
    

Traceback (most recent call last):
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-27-da378ef86060>", line 42, in <module>
    perform_expectation(model, trainloder_expectation)
  File "<ipython-input-15-e151fbc59213>", line 66, in perform_expectation
    calculate_element_probb(prob, item_1, item_4, item_2)
  File "<ipython-input-15-e151fbc59213>", line 42, in calculate_element_probb
    pool.join()  # postpones the execution of next line of code until all processes in the queue are done.
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/pool.py", line 556, in join
    self._worker_handler.join()
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/threading.py", line 1032, in join
    self._wait_for_tstate_lock()
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/threading.py", line 1

Traceback (most recent call last):


KeyboardInterrupt: 

Traceback (most recent call last):
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/dipika16/anaconda3/envs/video_r/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._t

In [22]:
print(f"Validation:: Epoch {epoch}, Probability Accuracy {val_acc}")

Validation:: Epoch 105, Probability Accuracy 61.02425300046298


In [23]:
print(f"Validation:: Epoch {epoch}, Probability Accuracy {best_val_acc}")

Validation:: Epoch 105, Probability Accuracy 63.70656599831428


In [24]:
torch.save(model.state_dict(),
"/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-speed/final-em-maximized.wt")

In [24]:
config.output_dir

'/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-split1/'

In [34]:
model.load_state_dict(torch.load(config.output_dir + "ms-tcn-emmax-best-model.wt"))
# model.load_state_dict(torch.load(config.output_dir + "ms-tcn-initial-15-epochs.wt"))

<All keys matched successfully>