In [None]:
from google.colab import drive 
# Mount Google Drive so that paths under /content/drive/ become readable.
DRIVE_MOUNT_POINT = "/content/drive/"
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

# Datasets

class MultimodalDataset(Dataset):
  """Index-aligned bundle of several modality containers plus one label container.

  ``data`` is a sequence of per-modality containers; item ``idx`` is the tuple
  ``(data[0][idx], ..., data[M-1][idx], labels[idx])`` so the label is always
  the last element.
  """

  def __init__(self, data, labels):
    self.data = data
    self.labels = labels
    self.num_modalities = len(self.data)

  def __len__(self):
    # The label container defines how many samples there are.
    return len(self.labels)

  def __getitem__(self, idx):
    per_modality = tuple(self.data[m][idx] for m in range(self.num_modalities))
    return per_modality + (self.labels[idx],)

# Models

# Mini-batch size shared by every DataLoader built inside critic_ce_alignment.
batch_size = 2 ** 8

def sinkhorn_probs(matrix, x1_probs, x2_probs):
    """Run one alternating Sinkhorn sweep towards the target marginals.

    Columns are first rescaled to sum to ``x2_probs``; if the row sums are then
    within 0.01 (absolute) of ``x1_probs`` the sweep stops early. Otherwise rows
    are rescaled to ``x1_probs`` and the column sums are checked against
    ``x2_probs`` with the same tolerance.

    Returns:
        ``(matrix, converged)``: the rescaled matrix and a bool convergence flag.
    """
    eps = 1e-8
    tol = 0.01

    col_totals = torch.sum(matrix, dim=0, keepdim=True)
    scaled = matrix / (col_totals + eps) * x2_probs[None]
    if torch.allclose(torch.sum(scaled, dim=1), x1_probs, rtol=0, atol=tol):
        return scaled, True

    row_totals = torch.sum(scaled, dim=1, keepdim=True)
    scaled = scaled / (row_totals + eps) * x1_probs[:, None]
    converged = torch.allclose(torch.sum(scaled, dim=0), x2_probs, rtol=0, atol=tol)
    return scaled, converged

def mlp(dim, hidden_dim, output_dim, layers, activation):
    """Build a fully connected network as an ``nn.Sequential``.

    Layout: Linear(dim -> hidden) + act, then ``layers`` blocks of
    Linear(hidden -> hidden) + act, then Linear(hidden -> output_dim).

    Args:
        activation: ``'relu'`` or ``'tanh'``; any other key raises KeyError.
    """
    act_cls = {'relu': nn.ReLU, 'tanh': nn.Tanh}[activation]

    modules = [nn.Linear(dim, hidden_dim), act_cls()]
    for _ in range(layers):
        modules.extend([nn.Linear(hidden_dim, hidden_dim), act_cls()])
    modules.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*modules)

def simple_discrim(xs, y, num_labels, device='cuda'):
    """Count-based discriminator over argmax-discretised inputs.

    Builds a smoothed joint histogram P(argmax x_1, ..., argmax x_k, y) from the
    given samples. The returned function maps a batch of inputs to
    ``log P(argmax x_1, ..., argmax x_k, y)`` for every label y. Softmax over the
    last axis of that output gives P(y | discretised x).

    Args:
        xs: list of 2-D tensors (N, d_i); each row is reduced to its argmax.
        y: N integer class indices < ``num_labels`` (e.g. shape (N,) or (N, 1)).
        num_labels: number of label classes.
        device: device holding the probability table (default 'cuda', as before).

    Returns:
        ``f(*x)`` accepting the inputs either as separate tensors ``f(x1, x2)``
        or as one list ``f([x1, x2])`` (the convention ``CEAlignmentInformation``
        uses), returning a (batch, num_labels) tensor.
    """
    shape = [x.size(1) for x in xs] + [num_labels]
    p = torch.ones(*shape) * 1e-8  # additive smoothing so log() stays finite
    # One index tensor per axis; index_put_ with accumulate=True counts
    # repeated index tuples, replacing the per-sample Python loop.
    index = tuple(torch.argmax(x, dim=1).cpu() for x in xs)
    index = index + (torch.as_tensor(y).reshape(-1).long().cpu(),)
    p.index_put_(index, torch.ones(len(index[-1])), accumulate=True)
    p /= torch.sum(p)
    p = p.to(device)

    def f(*x):
        # Accept the list-of-modalities calling convention as well.
        if len(x) == 1 and isinstance(x[0], (list, tuple)):
            x = x[0]
        idx = tuple(torch.argmax(xx, dim=1) for xx in x)
        return torch.log(p[idx])

    return f

class Discrim(nn.Module):
    """MLP label predictor over the concatenation of one or more modality tensors."""

    def __init__(self, x_dim, hidden_dim, num_labels, layers, activation):
        super().__init__()
        # x_dim must equal the total last-axis width of the concatenated inputs.
        self.mlp = mlp(x_dim, hidden_dim, num_labels, layers, activation)

    def forward(self, *x):
        """Return per-label logits for the concatenated inputs.

        Inputs may be passed as separate tensors (``model(x1, x2)``, as
        ``train_discrim`` does) or as one list/tuple (``model([x1, x2])``, as
        ``CEAlignmentInformation`` does). Previously the list form made
        ``torch.cat`` receive a list instead of tensors and fail.
        """
        if len(x) == 1 and isinstance(x[0], (list, tuple)):
            x = tuple(x[0])
        return self.mlp(torch.cat(x, dim=-1))

class CEAlignment(nn.Module):
    """Per-label soft alignment (coupling) between the x1 and x2 rows of a batch.

    Each input row is embedded by its own MLP into ``num_labels`` vectors of size
    ``embed_dim``, standardised along the embedding axis. For every label the
    pairwise dot products (scaled by 1/sqrt(embed_dim)) are exponentiated and
    Sinkhorn-normalised until ``sinkhorn_probs`` reports that the marginals
    match ``x1_probs[:, label]`` / ``x2_probs[:, label]``, or until
    ``sinkhorn_iters`` sweeps have run.

    Args:
        sinkhorn_iters: cap on Sinkhorn sweeps per label (default 500, the
            previously hard-coded value).
    """

    def __init__(self, x1_dim, x2_dim, hidden_dim, embed_dim, num_labels, layers, activation,
                 sinkhorn_iters=500):
        super().__init__()

        self.num_labels = num_labels
        self.sinkhorn_iters = sinkhorn_iters
        self.mlp1 = mlp(x1_dim, hidden_dim, embed_dim * num_labels, layers, activation)
        self.mlp2 = mlp(x2_dim, hidden_dim, embed_dim * num_labels, layers, activation)

    @staticmethod
    def _standardize(q):
        """Zero-mean, unit-variance normalise ``q`` along its last (embedding) axis."""
        return (q - torch.mean(q, dim=2, keepdim=True)) / torch.sqrt(torch.var(q, dim=2, keepdim=True) + 1e-8)

    def forward(self, x1, x2, x1_probs, x2_probs):
        """Return the normalised coupling of shape (batch_x1, batch_x2, num_labels).

        Args:
            x1, x2: 2-D inputs (batch, x1_dim) / (batch, x2_dim).
            x1_probs, x2_probs: (batch, num_labels) target marginals, e.g. P(y|x1), P(y|x2).

        Raises:
            Exception('nan') if the normalised coupling contains NaNs (the
            offending tensors are printed first).
        """
        # (batch, num_labels * embed_dim) -> (batch, num_labels, embed_dim)
        q_x1 = self._standardize(self.mlp1(x1).unflatten(1, (self.num_labels, -1)))
        q_x2 = self._standardize(self.mlp2(x2).unflatten(1, (self.num_labels, -1)))

        align_logits = torch.einsum('ahx, bhx -> abh', q_x1, q_x2) / math.sqrt(q_x1.size(-1))
        align = torch.exp(align_logits)

        normalized = []
        for i in range(align.size(-1)):
            current = align[..., i]
            for _ in range(self.sinkhorn_iters):
                current, converged = sinkhorn_probs(current, x1_probs[:, i], x2_probs[:, i])
                if converged:
                    break
            normalized.append(current)
        normalized = torch.stack(normalized, dim=-1)

        if torch.any(torch.isnan(normalized)):
            print(align_logits)
            print(align)
            print(normalized)
            raise Exception('nan')

        return normalized

class CEAlignmentInformation(nn.Module):
    """Estimate redundancy / unique / synergy terms for (x1, x2) -> y.

    A trainable ``CEAlignment`` produces a per-label coupling between the x1
    and x2 rows of a batch. Three fixed label predictors supply P(y | x1),
    P(y | x2) and P(y | x1, x2). Each is called with a *list* of modality
    tensors and must return per-label logits. Only the alignment is optimised
    (see ``align_parameters``). All returned information terms are in nats.

    Args:
        x1_dim, x2_dim: flattened feature sizes of the two modalities.
        hidden_dim, embed_dim, layers, activation: CEAlignment MLP settings.
        num_labels: number of label classes.
        discrim_1, discrim_2, discrim_12: frozen predictors (nn.Module or callable).
        p_y: label marginal, stored as the ``p_y`` buffer.
    """

    def __init__(self, x1_dim, x2_dim, hidden_dim, embed_dim, num_labels,
                 layers, activation, discrim_1, discrim_2, discrim_12, p_y):
        super().__init__()
        self.num_labels = num_labels
        self.align = CEAlignment(x1_dim, x2_dim, hidden_dim, embed_dim, num_labels, layers, activation)
        self.discrim_1 = discrim_1
        self.discrim_2 = discrim_2
        self.discrim_12 = discrim_12
        self._freeze_discrims()
        self.register_buffer('p_y', p_y)

    def _freeze_discrims(self):
        """Put every predictor that is an nn.Module into eval mode."""
        for d in (self.discrim_1, self.discrim_2, self.discrim_12):
            if isinstance(d, nn.Module):
                d.eval()

    def train(self, mode=True):
        """Switch the alignment's mode but keep the predictors in eval mode.

        ``nn.Module.train`` recurses into registered submodules, so without
        this override ``model.train()`` would re-enable e.g. dropout in the
        frozen predictors. ``eval()`` routes through here as well.
        """
        super().train(mode)
        self._freeze_discrims()
        return self

    def align_parameters(self):
        return list(self.align.parameters())

    @staticmethod
    def _safe_log(p):
        """Log of a detached copy of ``p`` with exact zeros replaced by 1e-8."""
        p = p.detach().clone()
        p[p == 0] += 1e-8
        return torch.log(p)

    def forward(self, x1, x2, y):
        """Return ``(loss, stack([redundancy, unique1, unique2, synergy]), align)``.

        ``y`` is accepted for interface compatibility. The estimates use the
        predictors' soft outputs rather than the observed labels.
        """
        with torch.no_grad():
            p_y_x1 = F.softmax(self.discrim_1([x1]), dim=-1)
            p_y_x2 = F.softmax(self.discrim_2([x2]), dim=-1)
            p_y_x1x2 = F.softmax(self.discrim_12([x1, x2]), dim=-1)

        align = self.align(torch.flatten(x1, 1, -1), torch.flatten(x2, 1, -1), p_y_x1, p_y_x2)

        # Nudge degenerate marginal entries away from exactly 0 / 1 on a local
        # copy. The original mutated the registered buffer on every call.
        p_y = self.p_y.clone()
        p_y[p_y == 0] += 1e-8
        p_y[p_y == 1] -= 1e-8
        log_p_y = torch.log(p_y)

        # sample method: P(X1)
        # coeff: P(Y | X1) Q(X2 | X1, Y)
        # log term: log Q(X2 | X1, Y) - logsum_Y' Q(X2 | X1, Y') Q(Y' | X1)
        q_x2_x1y = align / (torch.sum(align, dim=1, keepdim=True) + 1e-8)
        log_term = torch.log(q_x2_x1y + 1e-8) - torch.log(torch.einsum('aby, ay -> ab', q_x2_x1y, p_y_x1) + 1e-8)[:, :, None]
        weight = p_y_x1[:, None, :] * q_x2_x1y

        # Alignment training objective (summed over x2 and y, averaged over x1).
        loss = torch.mean(torch.sum(torch.sum(weight * log_term, dim=-1), dim=-1))

        # Plug-in mutual informations from the fixed predictors.
        mi_y_x1 = torch.mean(torch.sum(p_y_x1 * (self._safe_log(p_y_x1) - log_p_y[None]), dim=-1))
        mi_y_x2 = torch.mean(torch.sum(p_y_x2 * (self._safe_log(p_y_x2) - log_p_y[None]), dim=-1))
        mi_y_x1x2 = torch.mean(torch.sum(p_y_x1x2 * (self._safe_log(p_y_x1x2) - log_p_y[None]), dim=-1))

        # MI under the aligned distribution Q: anchored by x1, then averaged.
        mi_q_terms = weight * (log_term + torch.log(p_y_x1 + 1e-8)[:, None, :] - torch.log(p_y + 1e-8)[None, None, :])
        mi_q_y_x1x2 = torch.mean(torch.sum(torch.sum(mi_q_terms, dim=-1), dim=-1))

        redundancy = mi_y_x1 + mi_y_x2 - mi_q_y_x1x2
        unique1 = mi_q_y_x1x2 - mi_y_x2
        unique2 = mi_q_y_x1x2 - mi_y_x1
        synergy = mi_y_x1x2 - mi_q_y_x1x2

        return loss, torch.stack([redundancy, unique1, unique2, synergy], dim=0), align

# Training Loops
from tqdm import tqdm
def train_discrim(model, train_loader, optimizer, data_type, num_epoch=40):
    """Fit a label predictor with cross-entropy.

    Args:
        model: called as ``model(*inputs)``; must return (batch, num_labels) logits.
        train_loader: yields tuples of tensors (e.g. from ``MultimodalDataset``).
        optimizer: optimiser over ``model``'s parameters.
        data_type: groups of 1-based positions into each batch tuple. Position 0
            wraps to ``data_batch[-1]`` (the label in ``MultimodalDataset``).
            Each group is concatenated on dim 1 and moved to the GPU. Every
            group except the last is a float model input; the last is the target.
        num_epoch: number of passes over ``train_loader``.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    for _iter in range(num_epoch):
        print(_iter)
        for i_batch, data_batch in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()

            inputs = []
            for j, group in enumerate(data_type):
                x_batch = torch.cat([data_batch[k - 1] for k in group], dim=1).cuda()
                if j != len(data_type) - 1:
                    x_batch = x_batch.float()
                inputs.append(x_batch)
            *features, y = inputs

            logits = model(*features)
            # CrossEntropyLoss needs int64 class indices; labels produced with
            # np.vectorize can be int32 depending on the platform.
            loss = criterion(logits, y.squeeze(-1).long())
            loss.backward()

            optimizer.step()

            if (_iter + 1) % 20 == 0 and i_batch % 1024 == 0:
                print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())

def eval_discrim(model, test_loader, data_type):
    """Print and return the mean per-batch cross-entropy of ``model`` on ``test_loader``.

    ``data_type`` follows the same convention as in ``train_discrim``: 1-based
    batch positions, 0 wrapping to the last element, last group is the target.

    Returns:
        Mean batch loss as a float, or ``nan`` if the loader yields no batches
        (previously this raised ZeroDivisionError).
    """
    criterion = nn.CrossEntropyLoss()
    losses = []
    # Evaluation only: skip building autograd graphs.
    with torch.no_grad():
        for i_batch, data_batch in enumerate(test_loader):
            inputs = []
            for j, group in enumerate(data_type):
                x_batch = torch.cat([data_batch[k - 1] for k in group], dim=1).cuda()
                if j != len(data_type) - 1:
                    x_batch = x_batch.float()
                inputs.append(x_batch)
            *features, y = inputs

            logits = model(*features)
            loss = criterion(logits, y.squeeze(-1).long())
            losses.append(loss.item())

            if i_batch % 1024 == 0:
                print('i_batch: ', i_batch, ' loss: ', loss.item())
    mean_loss = sum(losses) / len(losses) if losses else float('nan')
    print('Eval loss:', mean_loss)
    return mean_loss

def train_ce_alignment(model, train_loader, opt_align, data_type, num_epoch=10):
    """Optimise the alignment parameters of ``model`` with ``opt_align``.

    ``data_type`` holds three groups of 1-based batch positions for x1, x2 and
    y (position 0 wraps to ``data_batch[-1]``). Each group's tensors are
    concatenated along dim 1 and moved to the GPU; x1 and x2 are cast to float.
    ``model`` must return ``(loss, _, _)``.
    """
    x1_pos, x2_pos, y_pos = data_type[0], data_type[1], data_type[2]
    for epoch in range(num_epoch):
        print(epoch)
        for batch in tqdm(train_loader):
            opt_align.zero_grad()

            x1_batch = torch.cat([batch[p - 1] for p in x1_pos], dim=1).float().cuda()
            x2_batch = torch.cat([batch[p - 1] for p in x2_pos], dim=1).float().cuda()
            y_batch = torch.cat([batch[p - 1] for p in y_pos], dim=1).cuda()

            loss, _, _ = model(x1_batch, x2_batch, y_batch)
            loss.backward()
            opt_align.step()

def eval_ce_alignment(model, test_loader, data_type):
    """Run ``model`` over ``test_loader`` without gradients and collect its outputs.

    ``data_type`` uses the same convention as ``train_ce_alignment``.

    Returns:
        ``(results, aligns)``. ``results`` is a tensor stacking one
        ``[redundancy, unique1, unique2, synergy]`` row per batch; ``aligns``
        is the list of per-batch alignment tensors.

    Raises:
        ValueError: if ``test_loader`` yields no batches (e.g. the test set is
            smaller than one batch and ``drop_last=True``), instead of the
            opaque ``torch.stack`` error on an empty list.
    """
    results = []
    aligns = []
    x1_pos, x2_pos, y_pos = data_type[0], data_type[1], data_type[2]

    for data_batch in test_loader:
        x1_batch = torch.cat([data_batch[p - 1] for p in x1_pos], dim=1).float().cuda()
        x2_batch = torch.cat([data_batch[p - 1] for p in x2_pos], dim=1).float().cuda()
        y_batch = torch.cat([data_batch[p - 1] for p in y_pos], dim=1).cuda()

        with torch.no_grad():
            _, result, align = model(x1_batch, x2_batch, y_batch)
        results.append(result)
        aligns.append(align)

    if not results:
        raise ValueError(
            'eval_ce_alignment: test_loader yielded no batches '
            '(is the test set smaller than one batch with drop_last=True?)')

    return torch.stack(results, dim=0), aligns

def critic_ce_alignment(x1, x2, labels, num_labels, train_ds, test_ds, discrim_1=None, discrim_2=None, discrim_12=None, learned_discrim=True, shuffle=True, discrim_epochs=40, ce_epochs=10):
    """Train a CE alignment on ``train_ds`` and estimate PID terms on ``test_ds``.

    Predictor selection:
      * ``discrim_1`` given: the three supplied predictors are used as-is
        (``learned_discrim`` and ``discrim_epochs`` are then ignored);
      * else if ``learned_discrim``: three MLP ``Discrim`` models are trained
        for ``discrim_epochs`` epochs and evaluated on ``test_ds``;
      * else: count-based ``simple_discrim`` tables built from x1/x2/labels.

    Args:
        x1, x2: full-data tensors for the two modalities (used for sizes and,
            in the count-based path, for the tables).
        labels: integer class labels, typically shape (N, 1).
        num_labels: number of classes.
        train_ds, test_ds: datasets whose items are (x1, x2, label) tuples.
        shuffle: shuffle the training loaders.
        ce_epochs: epochs of alignment training.

    Returns:
        ``(results, aligns, (model, discrim_1, discrim_2, discrim_12, p_y))``
        where ``results`` stacks one [redundancy, unique1, unique2, synergy]
        row (in nats) per full test batch.
    """
    if discrim_1 is not None:
        model_discrim_1, model_discrim_2, model_discrim_12 = discrim_1, discrim_2, discrim_12
    elif learned_discrim:
        model_discrim_1 = Discrim(x_dim=x1.size(1), hidden_dim=32, num_labels=num_labels, layers=3, activation='relu').cuda()
        model_discrim_2 = Discrim(x_dim=x2.size(1), hidden_dim=32, num_labels=num_labels, layers=3, activation='relu').cuda()
        model_discrim_12 = Discrim(x_dim=x1.size(1) + x2.size(1), hidden_dim=32, num_labels=num_labels, layers=3, activation='relu').cuda()

        # Batch positions are 1-based; 0 wraps to the last tuple element (the label).
        for model, data_type in [
            (model_discrim_1, ([1], [0])),
            (model_discrim_2, ([2], [0])),
            (model_discrim_12, ([1], [2], [0])),
        ]:
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            train_loader1 = DataLoader(train_ds, shuffle=shuffle, drop_last=True,
                                       batch_size=batch_size,
                                       num_workers=1)
            train_discrim(model, train_loader1, optimizer, data_type=data_type, num_epoch=discrim_epochs)
            model.eval()
            test_loader1 = DataLoader(test_ds, shuffle=False, drop_last=False,
                                      batch_size=batch_size, num_workers=1)
            eval_discrim(model, test_loader1, data_type=data_type)
    else:
        model_discrim_1 = simple_discrim([x1], labels, num_labels)
        model_discrim_2 = simple_discrim([x2], labels, num_labels)
        model_discrim_12 = simple_discrim([x1, x2], labels, num_labels)

    # num_classes guarantees p_y has num_labels entries even when the highest
    # class never occurs; one_hot also requires int64 input.
    p_y = torch.sum(nn.functional.one_hot(labels.squeeze(-1).long(), num_classes=num_labels), dim=0) / len(labels)

    model = CEAlignmentInformation(x1_dim=math.prod(x1.shape[1:]), x2_dim=math.prod(x2.shape[1:]),
        hidden_dim=32, embed_dim=10, num_labels=num_labels, layers=3, activation='relu',
        discrim_1=model_discrim_1, discrim_2=model_discrim_2, discrim_12=model_discrim_12,
        p_y=p_y).cuda()
    opt_align = optim.Adam(model.align_parameters(), lr=1e-3)

    train_loader1 = DataLoader(train_ds, shuffle=shuffle, drop_last=True,
                               batch_size=batch_size,
                               num_workers=1)
    # drop_last=True: the alignment couples whole batches, so keep them full-size.
    test_loader1 = DataLoader(test_ds, shuffle=False, drop_last=True,
                              batch_size=batch_size,
                              num_workers=1)

    # Train and estimate mutual information
    model.train()
    train_ce_alignment(model, train_loader1, opt_align, data_type=([1], [2], [0]), num_epoch=ce_epochs)

    model.eval()
    results, aligns = eval_ce_alignment(model, test_loader1, data_type=([1], [2], [0]))
    return results, aligns, (model, model_discrim_1, model_discrim_2, model_discrim_12, p_y)

### Making dataset pickles

In [None]:
# NOTE(review): get_dataloader is never imported in this notebook; presumably it
# is MultiBench's datasets.affect.get_data.get_dataloader — import it before
# running this cell on a fresh kernel (TODO confirm the module path).
loader_kwargs = dict(modalities=[0, 1, 2], robust_test=False, max_pad=True,
                     data_type='mosei', max_seq_len=50, batch_size=1058)
traindata, validdata, testdata = get_dataloader('mosei_raw.pkl', **loader_kwargs)

In [None]:
# humor 18550, 4050, 15000
# torch.Size([32, 50, 371])
# torch.Size([32, 50, 81])
# torch.Size([32, 50, 300])
# torch.Size([32, 1])
import pickle
# Gather each field (x0, x1, x2, labels) across all test batches and
# concatenate once. Concatenating inside the loop re-copied the growing
# tensors on every batch (quadratic).
batches = list(testdata)
if batches:
  x0, x1, x2, lab = (torch.cat(parts, dim=0) for parts in zip(*batches))
else:
  x0, x1, x2, lab = None, None, None, None

In [None]:
print(x0.shape, x1.shape, x2.shape, lab.shape)
N = x0.shape[0]
# Collapse every non-batch axis so each modality becomes an (N, features) matrix.
x0, x1, x2 = (t.reshape((N, -1)) for t in (x0, x1, x2))
print(x0.shape, x1.shape, x2.shape, lab.shape)

# Split dictionary that the next cell pickles to disk.
data = {'x0': x0, 'x1': x1, 'x2': x2, 'labels': lab}

In [None]:
# Only the test split is written here (the concatenation cell reads `testdata`).
# The loading cell below also expects mosei_train/valid_ce_dataset.pkl,
# presumably produced by re-running these cells on traindata / validdata.
out_path = 'mosei_test_ce_dataset.pkl'
with open(out_path, 'wb') as fh:
  pickle.dump(data, fh)

### Start of training pipeline

In [None]:
!pip install memory_profiler

In [None]:
import pickle
# NOTE: pickle.load can execute arbitrary code; only load these locally
# produced dataset files.
def _load_pickle(path):
  """Read and return one pickled object from ``path``."""
  with open(path, 'rb') as fh:
    return pickle.load(fh)

valid_data = _load_pickle('mosei_valid_ce_dataset.pkl')
train_data = _load_pickle('mosei_train_ce_dataset.pkl')
test_data = _load_pickle('mosei_test_ce_dataset.pkl')

In [None]:
# Predicted labels that replace the ground-truth labels of each split.
_pred_labels = {}
for _split in ('test', 'train'):
  with open(f'ce_preds/mosei_{_split}_outer_02_pred.pkl', 'rb') as f:
    _pred_labels[_split] = pickle.load(f)
test_pred_labels, train_pred_labels = _pred_labels['test'], _pred_labels['train']

In [None]:
def flatten(L):
  """Wrap each element of ``L.tolist()`` in its own one-item list (N -> N x 1)."""
  return list(map(lambda item: [item], L.tolist()))
for split_data, split_preds in ((test_data, test_pred_labels), (train_data, train_pred_labels)):
  split_data['labels'] = flatten(split_preds)

In [None]:
import numpy as np
def replace_neg(x):
  """Binarise labels: values <= 0 map to 0, everything else (including NaN) to 1.

  Returns an int64 numpy array with the shape of ``np.asarray(x)``. Explicit
  int64 makes the labels collate to torch.long targets, which
  nn.CrossEntropyLoss requires. The previous np.vectorize version inferred
  the platform int (int32 on Windows) and raised on empty input.
  """
  return np.where(np.asarray(x) <= 0, 0, 1).astype(np.int64)

def get_mm_dataset(modalities, data):
  """Build a MultimodalDataset from the ``x{mod}`` entries of ``data`` with binarised labels."""
  L = [data[f'x{mod}'] for mod in modalities]
  labels = replace_neg(data['labels'])
  return MultimodalDataset(L, labels)

In [None]:

class ConcatEarly(nn.Module):
    """Early fusion by concatenating modality tensors along dimension 1.

    For the flattened (batch, features) inputs used here, dimension 1 is the
    feature axis.
    """

    def __init__(self):
        """Create the (parameter-free) fusion module."""
        super().__init__()

    def forward(self, modalities):
        """
        Concatenate ``modalities`` (a sequence of tensors) with ``torch.cat(..., dim=1)``.

        :param modalities: An iterable of modalities to combine
        """
        fused = torch.cat(modalities, 1)
        return fused

In [None]:
import torch
import sys
import os
sys.path.append(os.getcwd())
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))
from unimodals.common_models import GRU, MLP, Sequential, Identity, Linear  # noqa
from training_structures.Supervised_Learning import train, test
# from fusions.common_fusions import ConcatEarly  # noqa

mod1 = 0
mod2 = 2
# Flattened feature width of each modality, indexed by modality id
# (for the humor dataset these were 18550, 4050, 15000). The linear head's
# input size is the sum over the fused modalities.
dims = [35650, 3700, 15000]

for modalities in [[mod1], [mod2], [mod1, mod2]]:
  s = ''.join(str(m) for m in modalities)
  modelpath = f'ce_preds/mosei_tmp_{s}.pt'
  print(modelpath)

  valid_ds = get_mm_dataset(modalities, valid_data)
  train_ds = get_mm_dataset(modalities, train_data)
  test_ds = get_mm_dataset(modalities, test_data)

  traindl = DataLoader(train_ds, shuffle=True, num_workers=2, batch_size=32)
  valdl = DataLoader(valid_ds, shuffle=False, num_workers=2, batch_size=32)
  testdl = DataLoader(test_ds, shuffle=False, num_workers=2, batch_size=32)

  # Sanity check: shapes of one validation batch.
  for x in valdl:
    for y in x:
      print(y.shape)
    break

  dout = sum(dims[mod] for mod in modalities)
  print(dout)

  # Features are already flattened, so the encoders are identities and the
  # fusion is a plain concatenation feeding a linear classifier.
  encoders = [Identity().cuda(), Identity().cuda()]
  head = Linear(dout, 2).cuda()
  fusion = ConcatEarly().cuda()

  train(encoders, fusion, head, traindl, valdl, 10, task="classification", optimtype=torch.optim.AdamW,
        is_packed=False, lr=1e-3, save=modelpath, weight_decay=0.01,
        objective=torch.nn.CrossEntropyLoss())

  print("Testing:", modelpath)
  # The checkpoint is a whole pickled model; since PyTorch 2.6 torch.load
  # defaults to weights_only=True, which refuses to unpickle it.
  model = torch.load(modelpath, weights_only=False).cuda()
  test(model, testdl, 'mosei', is_packed=False,
      criterion=torch.nn.CrossEntropyLoss(), task="classification", no_robust=True)


In [None]:
mod1 = 0
mod2 = 2

def _load_frozen(path):
  """Load a whole pickled model, move it to the GPU, freeze its parameters and set eval mode."""
  # weights_only=False: these checkpoints are full nn.Module pickles, which the
  # PyTorch >= 2.6 default (weights_only=True) refuses to load.
  m = torch.load(path, weights_only=False).cuda()
  # `m.requires_grad = False` only sets a plain attribute; requires_grad_
  # actually disables gradients for every parameter.
  m.requires_grad_(False)
  m.eval()
  return m

model1 = _load_frozen(f'ce_preds/mosei_tmp_{mod1}.pt')
model2 = _load_frozen(f'ce_preds/mosei_tmp_{mod2}.pt')
model12 = _load_frozen(f'ce_preds/mosei_tmp_{mod1}{mod2}.pt')

modalities = [mod1, mod2]

valid_ds = get_mm_dataset(modalities, valid_data)
train_ds = get_mm_dataset(modalities, train_data)
test_ds = get_mm_dataset(modalities, test_data)

In [None]:
import numpy as np

# replace_neg comes from the dataset-helper cell above (get_mm_dataset uses it);
# re-defining it here only shadowed that definition.
pred = replace_neg(train_data['labels'])

# Pretrained predictors are supplied, so critic_ce_alignment uses them directly
# and does not train its own (learned_discrim / discrim_epochs have no effect).
results = critic_ce_alignment(train_data[f'x{mod1}'], train_data[f'x{mod2}'], torch.tensor(pred), 2,
                    train_ds, test_ds,
                    discrim_1=model1, discrim_2=model2, discrim_12=model12,
                    learned_discrim=True, shuffle=True, discrim_epochs=40, ce_epochs=8)

In [None]:
import numpy as np
# results[0] stacks one [redundancy, unique1, unique2, synergy] row per test
# batch (in nats). Average over batches, convert to bits, clip negatives to 0.
res = results[0].cpu().numpy()
values = np.maximum(np.mean(res, axis=0) / np.log(2), 0)
print(', '.join(str(v) for v in values))
# The third entry is unique2; it was previously mislabelled "Unique1:".
for label, value in zip(("Redundancy:", "Unique1:", "Unique2:", "Synergy:"), values):
  print(label, value)

