In [None]:
# Change into the parent of the notebook's starting directory (presumably the project root) so that the
# project-local modules imported below (`data`, `model`, `conf`, ...) resolve — TODO confirm layout.
# The starting directory is remembered on the first run, so re-running this cell is idempotent instead of
# climbing one more level every time (which a bare `%cd ./../` would do).
import os

if '_NOTEBOOK_START_DIR' not in globals():
	_NOTEBOOK_START_DIR = os.getcwd()

os.chdir(os.path.dirname(_NOTEBOOK_START_DIR))
print(os.getcwd())  # `%cd` echoed the new directory; keep that feedback.

In [None]:
import logging
import os

from dotenv import load_dotenv

# Load variables from a `.env` file (python-dotenv) into `os.environ`; by default, variables already set in the environment are not overridden.
load_dotenv()
# NOTE(review): presumably read by project code imported later (e.g. `repro` / `conf`) — TODO confirm what INDEX selects.
os.environ['INDEX'] = '0'

# TODO: share the code with the `logs.py` file.

# Tab-separated log lines: logger name, timestamp, level, message.
logging.basicConfig(
	level=logging.DEBUG,
	style='{',
	format="{name}\t{asctime}\t{levelname}\t{message}\t",
)

# Turn off these third-party loggers (presumably to cut DEBUG noise). The child logger is listed
# explicitly because `Logger.disabled` is only checked on the logger that emits a record, not on its ancestors.
for _noisy_logger_name in ('PIL.PngImagePlugin', 'matplotlib', 'matplotlib.font_manager'):
	logging.getLogger(_noisy_logger_name).disabled = True

In [None]:
import repro  # Imported for its side effects.

In [None]:
from data import train_data_loader, test_data_loader, train_eval_data_loader
from model import model
import conf, conf_eval
from torch.nn import CrossEntropyLoss
import optim
from train import TorchContext

# TODO: move the optimizer and loss construction into a module shared with the `main` module.
optimizer = optim.init()
criterion = CrossEntropyLoss()

# TODO: move into a module to share with the `main` module.
# Bundle the model, its optimization objects and all data loaders for training/evaluation.
# NOTE(review): the device is taken from `conf_eval.DEVICE_EVAL` — confirm that is also meant for training.
torch_context = TorchContext(
	device=conf_eval.DEVICE_EVAL,
	model=model,
	criterion=criterion,
	optimizer=optimizer,
	train_data_loader=train_data_loader,
	train_eval_data_loader=train_eval_data_loader,
	test_data_loader=test_data_loader,
)

In [None]:
from tabulate import tabulate

epochs_metrics = await torch_context.train(epochs=conf.FIRST_ROUND_TRAIN_EPOCHS)

epochs_metrics = [(corrects / total, losses_sum / total) for (corrects, total, losses_sum) in epochs_metrics]
accuracies, losses = [accuracy for accuracy, _ in epochs_metrics], [loss for _, loss in epochs_metrics]

print(tabulate({'epoch': range(len(epochs_metrics)), 'accuracy': accuracies, 'loss': losses}, headers='keys', tablefmt='outline'))

In [None]:
import search_data

# Snapshot of the trained model's parameters, used as the unpruned reference in the sweep below.
# NOTE(review): it is later `.clone()`d and compared with `cosine`, so it is presumably one flattened tensor;
# `cpu=False` presumably leaves it on the model's device — TODO confirm against `search_data.params`.
base_vec = search_data.params(model, cpu=False)

In [None]:
import torch
from similarity import cosine
from exts.torch_exts import topp_mag

# Sweep the pruning ratio p from 0 to 1 in PRUNE_STEPS equal steps and record, for each p, the cosine
# similarity between the pruned copy of the parameter vector and the unpruned `base_vec`.
# Ratios are computed as exact Python floats from an integer step rather than by converting float32
# `torch.linspace` values, which produced ratios such as 0.07000000029802322.
PRUNE_STEPS = 100

ps_sims = []

for step in range(PRUNE_STEPS + 1):
	p = step / PRUNE_STEPS
	keep = (PRUNE_STEPS - step) / PRUNE_STEPS  # Equals 1 - p, without the floating-point rounding error.

	# Prune a copy so `base_vec` stays intact should `topp_mag` modify its input in place (not visible here).
	# `topp_mag(..., p=keep)` presumably keeps the top `keep` fraction of entries by magnitude — TODO confirm.
	pruned_vec = topp_mag(base_vec.clone(), p=keep)

	ps_sims.append((p, cosine(base_vec, pruned_vec)))

In [None]:
import pandas as pd

# Tabulate the sweep and score every (sim, p) point by its Euclidean distance to the utopia point
# (sim=1, p=1): full pruning with no loss of cosine similarity.
# Columns are coerced to float because `cosine` may return 0-d tensors (TODO confirm); otherwise the
# table and the plot would receive object-dtype columns of tensors.
df = pd.DataFrame(ps_sims, columns=['p', 'sim']).astype({'p': float, 'sim': float})

utopia = torch.Tensor((1, 1))
# One vectorized norm over all rows instead of a per-row `df.apply`. Using `utopia.dtype` (float32) keeps the
# precision of the former per-row `torch.dist` computation, and `.tolist()` stores plain floats so `dist`
# is a numeric column (the per-row version produced an object column of 0-d tensors).
points = torch.tensor(df[['sim', 'p']].to_numpy(), dtype=utopia.dtype)
df['dist'] = (points - utopia).norm(dim=1).tolist()

df

In [None]:
import plotly.express as px

# The sweep point closest to the utopia point is reported as the optimum.
optimal_point = df.loc[df['dist'].idxmin()]
print(f"Optimal Point: p={optimal_point['p']}, sim={optimal_point['sim']}.")

# Cosine similarity against pruning ratio, one marker per sweep point.
fig = px.line(
	df,
	x='p',
	y='sim',
	markers=True,
	title='Pareto Front',
	labels={'p': 'Pruning Ratio', 'sim': 'Cosine Similarity'},
)

# Overlay the ideal (utopia) point first, then the chosen optimum.
fig.add_scatter(
	x=(1,),
	y=(1,),
	mode='markers',
	name='Utopia',
	marker={'color': 'red', 'size': 10, 'symbol': 'star'},
)
fig.add_scatter(
	x=(optimal_point['p'],),
	y=(optimal_point['sim'],),
	mode='markers',
	name='Optima',
	marker={'color': 'gray', 'size': 10},
)

# For a 1:1 aspect ratio, call: fig.update_layout(yaxis_scaleanchor='x')

fig.show()