In [3]:
# mlp torch model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import random

# import torch dataset and dataloader
from torch.utils.data import Dataset, DataLoader

In [4]:
class EnsembleDataset(Dataset):
    def __init__(self, df, split="train", chunk_size=128):
        self.split = split
        self.chunk_size = chunk_size
        self.df = df
        if split != "train":
            new_df = []
            for i, row in self.df.iterrows():
                # split features and target into chunks
                features = row["features"]
                targets = row["target"]
                n_chunks = len(features) // chunk_size
                # n_chunks = int(np.ceil(n_chunks))
                feature_chunks = np.array_split(features, n_chunks)
                target_chunks = np.array_split(targets, n_chunks)

                for feature, target in zip(feature_chunks, target_chunks):
                    new_df.append(
                        {
                            "file": row["file"],
                            "features": feature,
                            "target": target,
                        }
                    )
            self.df = pd.DataFrame.from_dict(new_df)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        item = self.df.iloc[idx]

        X = item["features"]
        y = item["target"]

        if self.split == "train":
            # random choose chunk_size indices
            indices = random.sample(list(range(len(X))), self.chunk_size)
            X = X[indices]
            y = y[indices]

        return {
            "file": item["file"],
            "X": X,
            "y": y,
        }


def collate_fn(batch):
    X = torch.stack([torch.tensor(x["X"], dtype=torch.float32) for x in batch])
    y = torch.tensor([x["y"] for x in batch], dtype=torch.float32)
    return X, y, batch[0]["file"]


In [5]:
class Model(nn.Module):
    def __init__(self, input_dim) -> None:
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 64)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(64, 16)
        self.relu2 = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

        self.classifier = nn.Linear(16, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.relu2(x)

        x = self.classifier(x)
        return x

In [19]:
import os
from tqdm import tqdm
from loss import PairwiseHingeLoss
from scipy.stats import kendalltau
from collections import defaultdict

# example data
INPUT_DIM = 660

output_dir = "/home/thanh/google_fast_or_slow/outputs_embeddings/outputs_embeddings"

all_files = [os.path.join(output_dir, file) for file in os.listdir(output_dir) if "_valid_" in file]
file_to_features = defaultdict(list)
file_to_target = defaultdict(list)
for file in all_files:
    data = np.load(file, allow_pickle=True).item()
    for graph, emb, target in zip(data["files"], data["embeddings"], data["gts"]):
        file_to_features[graph].append(emb)
        file_to_target[graph] = target

all_graphs = list(file_to_features.keys())
test_index = np.random.choice(len(all_graphs), 1)[0]
test_graphs = [all_graphs[test_index]]
train_graphs = [graph for graph in all_graphs if graph not in test_graphs]

# train_graphs = all_graphs[:-1]
# test_graphs = all_graphs[-1:]

print(test_graphs)

train_df = pd.DataFrame.from_dict(
    {
        "file": train_graphs,
        "features": [np.concatenate(file_to_features[file], axis=1) for file in train_graphs],
        "target": [file_to_target[file] for file in train_graphs],
    }
)

test_df = pd.DataFrame.from_dict(
    {
        "file": test_graphs,
        "features": [np.concatenate(file_to_features[file], axis=1) for file in test_graphs],
        "target": [file_to_target[file] for file in test_graphs],
    }
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for CHUNK_SIZE in [128, 256, 512]:
    train_dataset = EnsembleDataset(train_df, "train", CHUNK_SIZE)
    test_dataset = EnsembleDataset(test_df, "test", CHUNK_SIZE)

    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
    val_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    model = Model(INPUT_DIM).to(device)

    n_epochs = 1000
    lr = 0.001
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    criterion = PairwiseHingeLoss()

    best_score = 0
    for epoch in tqdm(range(n_epochs)):
        model.train()
        train_loss = 0
        for batch in train_dataloader:
            optimizer.zero_grad()
            x, y, file = batch
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y, n=torch.tensor([y_pred.shape[1]], device=y_pred.device))
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            scheduler.step()

        model.eval()

        file_preds = defaultdict(list)
        file_gts = defaultdict(list)
        test_loss = 0
        for batch in val_dataloader:
            x, y, file = batch
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y, n=torch.tensor([y_pred.shape[1]], device=y_pred.device))
            test_loss += loss.item()
            file_preds[file].extend(y_pred[0].detach().cpu().numpy().tolist())
            file_gts[file].extend(y[0].detach().cpu().numpy().tolist())
        
        # calculate metric for each file
        all_taus = []
        for file in file_preds:
            tau, _ = kendalltau(file_preds[file], np.array(file_gts[file]).reshape(-1))
            # print(f"File {file} has tau {tau}")
            all_taus.append(tau)

        if np.mean(all_taus) > best_score:
            best_score = np.mean(all_taus)
            print(f"""
                Epoch: {epoch} \t Train loss: {train_loss / len(train_dataloader)} \t Test loss: {test_loss / len(val_dataloader)} \t Average tau: {np.mean(all_taus)}
                --------------------------------------------------------
            """)
            torch.save(model.state_dict(), f"mlp_ensemble_{CHUNK_SIZE}.pt")

    print(f"Best score: {best_score} with chunk size {CHUNK_SIZE}")

['resnet50.4x4.fp16.npz']


  0%|          | 2/1000 [00:00<01:57,  8.51it/s]


                Epoch: 0 	 Train loss: 0.3582819551229477 	 Test loss: 0.22334577420423196 	 Average tau: 0.5958446290200594
                --------------------------------------------------------
            

                Epoch: 1 	 Train loss: 0.3307611793279648 	 Test loss: 0.22297775641430256 	 Average tau: 0.5959503618594728
                --------------------------------------------------------
            


  0%|          | 4/1000 [00:00<01:58,  8.38it/s]


                Epoch: 2 	 Train loss: 0.2865614891052246 	 Test loss: 0.2237213375263436 	 Average tau: 0.5961204888621219
                --------------------------------------------------------
            


 68%|██████▊   | 682/1000 [01:21<00:36,  8.73it/s]


                Epoch: 680 	 Train loss: 0.2689455027381579 	 Test loss: 0.24727154228576395 	 Average tau: 0.6434016670802672
                --------------------------------------------------------
            


 70%|███████   | 704/1000 [01:23<00:33,  8.88it/s]


                Epoch: 702 	 Train loss: 0.2539753665526708 	 Test loss: 0.25347705774529034 	 Average tau: 0.6543364989149754
                --------------------------------------------------------
            

                Epoch: 703 	 Train loss: 0.25937436024347943 	 Test loss: 0.24723460196062577 	 Average tau: 0.6685551172107206
                --------------------------------------------------------
            


100%|██████████| 1000/1000 [02:01<00:00,  8.25it/s]


Best score: 0.6685551172107206 with chunk size 128


  0%|          | 2/1000 [00:00<01:16, 12.97it/s]


                Epoch: 0 	 Train loss: 0.35828570524851483 	 Test loss: 0.2226115188428334 	 Average tau: 0.5929406350861406
                --------------------------------------------------------
            

                Epoch: 1 	 Train loss: 0.3210902710755666 	 Test loss: 0.224845324243818 	 Average tau: 0.593336859655092
                --------------------------------------------------------
            

                Epoch: 2 	 Train loss: 0.3062420388062795 	 Test loss: 0.22458847221874056 	 Average tau: 0.5938430074777326
                --------------------------------------------------------
            


  0%|          | 4/1000 [00:00<01:16, 13.00it/s]


                Epoch: 3 	 Train loss: 0.290676087141037 	 Test loss: 0.22475089771406992 	 Average tau: 0.5941051696991585
                --------------------------------------------------------
            

                Epoch: 4 	 Train loss: 0.3041311129927635 	 Test loss: 0.22444704884574526 	 Average tau: 0.5942604411980086
                --------------------------------------------------------
            


  1%|          | 6/1000 [00:00<01:16, 13.08it/s]


                Epoch: 5 	 Train loss: 0.2877661883831024 	 Test loss: 0.22200242607366472 	 Average tau: 0.5947090100591115
                --------------------------------------------------------
            


  1%|          | 8/1000 [00:00<01:16, 12.98it/s]


                Epoch: 6 	 Train loss: 0.292085940639178 	 Test loss: 0.22173188910597846 	 Average tau: 0.5947317497729152
                --------------------------------------------------------
            

                Epoch: 8 	 Train loss: 0.2947072933117549 	 Test loss: 0.22274166274638402 	 Average tau: 0.5951768879682547
                --------------------------------------------------------
            


  1%|          | 10/1000 [00:00<01:16, 12.99it/s]


                Epoch: 9 	 Train loss: 0.28381893038749695 	 Test loss: 0.22229752512205214 	 Average tau: 0.5952119262692984
                --------------------------------------------------------
            


 68%|██████▊   | 678/1000 [00:52<00:24, 13.36it/s]


                Epoch: 678 	 Train loss: 0.25516920288403827 	 Test loss: 0.25358633129369645 	 Average tau: 0.62070914498156
                --------------------------------------------------------
            


 68%|██████▊   | 683/1000 [00:53<00:59,  5.35it/s]


                Epoch: 682 	 Train loss: 0.2601294244329135 	 Test loss: 0.233894767505782 	 Average tau: 0.665873005317078
                --------------------------------------------------------
            


 71%|███████   | 706/1000 [00:57<00:26, 10.95it/s]


                Epoch: 703 	 Train loss: 0.2779843658208847 	 Test loss: 0.2482170377458845 	 Average tau: 0.6724960706696206
                --------------------------------------------------------
            


100%|██████████| 1000/1000 [01:20<00:00, 12.46it/s]


Best score: 0.6724960706696206 with chunk size 256


  0%|          | 2/1000 [00:00<00:58, 17.02it/s]


                Epoch: 0 	 Train loss: 0.34731827427943546 	 Test loss: 0.22234828472137452 	 Average tau: 0.595575918039249
                --------------------------------------------------------
            

                Epoch: 2 	 Train loss: 0.2982201874256134 	 Test loss: 0.2242656856775284 	 Average tau: 0.5958582026175162
                --------------------------------------------------------
            


 64%|██████▍   | 640/1000 [00:41<00:21, 17.08it/s]


                Epoch: 636 	 Train loss: 0.26901625593503314 	 Test loss: 0.23280747085809708 	 Average tau: 0.6241021981465269
                --------------------------------------------------------
            


 65%|██████▍   | 648/1000 [00:42<00:23, 15.18it/s]


                Epoch: 645 	 Train loss: 0.27192391951878864 	 Test loss: 0.22994683235883712 	 Average tau: 0.6242464921836248
                --------------------------------------------------------
            

                Epoch: 647 	 Train loss: 0.2654834936062495 	 Test loss: 0.24934960454702376 	 Average tau: 0.6584910242572629
                --------------------------------------------------------
            


 66%|██████▌   | 656/1000 [00:42<00:20, 16.74it/s]


                Epoch: 653 	 Train loss: 0.27116520206133526 	 Test loss: 0.24884019494056703 	 Average tau: 0.6823647768518786
                --------------------------------------------------------
            


100%|██████████| 1000/1000 [01:03<00:00, 15.84it/s]

Best score: 0.6823647768518786 with chunk size 512





In [20]:
# inference
INFERENCE_CHUNK_SIZE = 512
model.load_state_dict(torch.load(f"mlp_ensemble_{INFERENCE_CHUNK_SIZE}.pt"))
output_dir = "/home/thanh/google_fast_or_slow/outputs_embeddings/outputs_embeddings"

all_files = [os.path.join(output_dir, file) for file in os.listdir(output_dir) if "_test_" in file]
file_to_features = defaultdict(list)
file_to_target = defaultdict(list)

for file in all_files:
    data = np.load(file, allow_pickle=True).item()
    for graph, emb, target in zip(data["files"], data["embeddings"], data["gts"]):
        file_to_features[graph].append(emb)
        file_to_target[graph] = target

all_graphs = list(file_to_features.keys())

inference_df = pd.DataFrame.from_dict(
    {
        "file": all_graphs,
        "features": [np.concatenate(file_to_features[file], axis=1) for file in all_graphs],
        "target": [file_to_target[file] for file in all_graphs],
    }
)

inference_dataset = EnsembleDataset(inference_df, "test", INFERENCE_CHUNK_SIZE)
inference_dataloader = DataLoader(inference_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

model.eval()

file_preds = defaultdict(list)

for batch in inference_dataloader:
    x, y, file = batch
    x, y = x.to(device), y.to(device)
    y_pred = model(x)
    file_preds[file].extend(y_pred[0].detach().cpu().numpy().tolist())


In [25]:
top_configs = []

for file in file_preds.keys():
    top_configs.append(np.array(file_preds[file]).reshape(-1).argsort())

prediction_df = pd.DataFrame.from_dict(
    {
        "ID": ["layout:xla:default:" + f.split(".")[0] for f in list(file_preds.keys())],
        "TopConfigs": [";".join([str(e) for e in top_configs[i]]) for i in range(len(top_configs))],
    }
)
prediction_df

Unnamed: 0,ID,TopConfigs
0,layout:xla:default:05ae41e26dd3c4c06390371a042...,730;585;874;498;787;197;709;345;244;232;323;36...
1,layout:xla:default:3e7156ac468dfb75cf5c9615e1e...,93;794;541;178;234;39;853;948;275;357;245;75;1...
2,layout:xla:default:5335ed13823b0a518ee3c79ba44...,757;395;337;522;916;178;958;779;179;540;687;12...
3,layout:xla:default:937ee0eb0d5d6151b7b8252933b...,85;132;93;757;107;163;991;145;850;786;125;51;8...
4,layout:xla:default:cd708819d3f5103afd6460b15e7...,366;245;153;310;711;848;111;313;256;689;232;33...
5,layout:xla:default:db59a991b7c607634f13570d52c...,425;411;729;620;367;286;571;788;379;59;899;658...
6,layout:xla:default:e8a3a1401b5e79f66d7037e424f...,428;310;625;920;731;216;619;217;153;879;52;415...
7,layout:xla:default:fbaa8bb6a1aed9988281085c910...,668;224;881;242;222;546;726;924;658;859;984;16...


In [26]:
prediction_df.to_csv("outputs_csv/mlp_ensemble.csv", index=False)