In [2]:
import os

# Must be set BEFORE importing anything that may load Intel's OpenMP runtime
# (the project module below presumably imports torch / numpy). It lets a second
# copy of the runtime load instead of aborting with "OMP: Error #15". Intel calls
# this an unsafe workaround (it can crash or silently give wrong results), so
# prefer fixing the environment when possible.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import import_ipynb  # presumably lets the next line import the CORAL_Methods_and_Classes notebook
import gc
from CORAL_Methods_and_Classes import *

In [3]:
# INR training helper; architecture and optimisation hyperparameters are set inside.
def Train_SIREN_INR(train_dataset, test_dataset, latent_dim, filename, epochs=100, *,
                    batch_size=16, shuffle_train=False, results_dir='./INR_models_trained/'):
    """Train a modulated SIREN INR and let ``train_inr`` save the best model.

    Builds train/test DataLoaders, instantiates a ``ModulatedSiren`` with the
    fixed architecture below and delegates the training loop to ``train_inr``.
    The model is saved under ``results_dir`` using ``filename`` as the run name
    (later cells load it from ``./INR_models_trained/<filename>.pt``).

    Args:
        train_dataset: dataset used for training (a ``FunctionDataset`` in this notebook).
        test_dataset: dataset passed to ``train_inr`` as the test loader.
        latent_dim: size of the latent modulation code.
        filename: run name passed to ``train_inr``.
        epochs: number of training epochs.
        batch_size: batch size of both loaders (default 16, as before).
        shuffle_train: shuffle the training loader (default False, as before).
        results_dir: output directory for the saved model; created if missing.

    Returns:
        None. The result is the model file written by ``train_inr``.
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Ensure device compatibility
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Architecture of the modulated SIREN
    input_dim = 2      # (x, y) coordinates
    hidden_dim = 256
    output_dim = 1     # one output channel
    num_layers = 4

    # Optimisation hyperparameters forwarded to train_inr
    inner_steps = 3
    inner_lr = 1e-5
    lr_code = 1e-5
    meta_lr_code = 1e-5
    weight_decay_code = 0
    lr_inr = 1e-4
    gamma_step = 0.5

    # Make sure the output directory exists before train_inr saves into it.
    os.makedirs(results_dir, exist_ok=True)

    inr = ModulatedSiren(
        dim_in=input_dim,
        dim_hidden=hidden_dim,
        dim_out=output_dim,
        num_layers=num_layers,
        latent_dim=latent_dim,
        modulation_net_dim_hidden=64,
        modulation_net_num_layers=2,
        modulate_scale=False,
        modulate_shift=True,
    ).to(device)

    # run training loop, which saves the best model
    train_inr(
        inr=inr,
        train_loader=train_loader,
        test_loader=test_loader,
        latent_dim=latent_dim,
        inner_steps=inner_steps,
        inner_lr=inner_lr,
        lr_code=lr_code,
        meta_lr_code=meta_lr_code,
        weight_decay_code=weight_decay_code,
        lr_inr=lr_inr,
        gamma_step=gamma_step,
        ntrain=len(train_dataset),
        ntest=len(test_dataset),
        epochs=epochs,
        saved_checkpoint=False,
        checkpoint=None,
        RESULTS_DIR=results_dir,
        run_name=filename,
        device=device,
    )

In [154]:
# load files
import pickle

# Train one input INR (for a) and one output INR (for u) per (shape, percentage) pair.
shape_list = ['triangle', 'bean', 'torus1', 'torus2']
percent_list = ['0.1', '0.5']

for shape in shape_list:
    for percent in percent_list:

        print(shape, percent, '-'*100)

        # NOTE: pickle.load can execute arbitrary code — only load dataset files generated locally.
        with open(f'./datasets_generated/dataset_geometry_percentage/dataset_{shape}_{percent}_1200.pkl', 'rb') as file:
            generated_dataset = pickle.load(file)

        # dataset for the current `percent` setting (see percent_list)
        coordinates, a_features, u_features, a_u_pairs = generated_dataset

        # scale up u features to avoid numerical errors
        u_features = u_features*10000

        # ensure all the same dtype
        coordinates = coordinates.to(torch.float32)
        a_features = a_features.to(torch.float32)
        u_features = u_features.to(torch.float32)

        # split by train and test
        N_train = 1000
        N_test = 200
        N_val = 0
        # Slicing past the end silently yields fewer samples; fail loudly instead.
        assert len(coordinates) >= N_train + N_test + N_val, (
            f'{shape} {percent}: dataset has {len(coordinates)} samples, '
            f'need {N_train + N_test + N_val}'
        )

        coordinates_train = coordinates[:N_train]
        a_features_train = a_features[:N_train]
        u_features_train = u_features[:N_train]

        # used in INR and MLP training
        coordinates_test = coordinates[N_train:N_train+N_test]
        a_features_test = a_features[N_train:N_train+N_test]
        u_features_test = u_features[N_train:N_train+N_test]

        # unseen data (empty while N_val == 0)
        coordinates_val = coordinates[N_train+N_test:N_train+N_test+N_val]
        a_features_val = a_features[N_train+N_test:N_train+N_test+N_val]
        u_features_val = u_features[N_train+N_test:N_train+N_test+N_val]

        # specify dimension of code
        latent_dim = 128

        # use torch Dataset format
        a_train_dataset = FunctionDataset(coordinates_train, a_features_train, latent_dim)
        a_test_dataset = FunctionDataset(coordinates_test, a_features_test, latent_dim)
        a_val_dataset = FunctionDataset(coordinates_val, a_features_val, latent_dim)

        u_train_dataset = FunctionDataset(coordinates_train, u_features_train, latent_dim)
        u_test_dataset = FunctionDataset(coordinates_test, u_features_test, latent_dim)
        u_val_dataset = FunctionDataset(coordinates_val, u_features_val, latent_dim)

        # Train the input INR
        Train_SIREN_INR(a_train_dataset, a_test_dataset, latent_dim, f'{shape}_input_{percent}_SIREN_INR', epochs=50)

        # Train the output INR
        Train_SIREN_INR(u_train_dataset, u_test_dataset, latent_dim, f'{shape}_output_{percent}_SIREN_INR', epochs=50)

        # Release leftover garbage (e.g. reference cycles from training) and cached
        # CUDA memory before the next (shape, percent) pair.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

triangle 0.1 ----------------------------------------------------------------------------------------------------
epoch = 0; train_loss = 60.58953491210937; test_loss=51.46541198730469
save model
epoch = 1; train_loss = 42.16449749755859; test_loss=35.97215118408203
save model
epoch = 2; train_loss = 29.696978454589843; test_loss=25.817020568847656
save model
epoch = 3; train_loss = 21.11116795349121; test_loss=17.11549301147461
save model
epoch = 4; train_loss = 13.067118400573731; test_loss=9.693951683044434
save model
epoch = 5; train_loss = 6.9820072517395015; test_loss=4.712105655670166
save model
epoch = 6; train_loss = 3.4822229747772218; test_loss=2.5702927923202514
save model
epoch = 7; train_loss = 2.270056447982788; test_loss=1.993308620452881
save model
epoch = 8; train_loss = 1.9652873163223266; test_loss=1.9833534622192384
save model
epoch = 9; train_loss = 1.826489981651306; test_loss=1.739971833229065
save model
epoch = 10; train_loss = 1.7072689933776855; test_loss=1.6

epoch = 0; train_loss = 59.98656079101563; test_loss=46.193968048095705
save model
epoch = 1; train_loss = 41.784460662841795; test_loss=32.340176849365236
save model
epoch = 2; train_loss = 29.549157012939453; test_loss=23.646775817871095
save model
epoch = 3; train_loss = 20.64802149963379; test_loss=14.715296516418457
save model
epoch = 4; train_loss = 12.622752128601075; test_loss=8.190580558776855
save model
epoch = 5; train_loss = 6.853143733978271; test_loss=3.979929218292236
save model
epoch = 6; train_loss = 3.5534858226776125; test_loss=2.281143236160278
save model
epoch = 7; train_loss = 2.7270619859695433; test_loss=2.016490840911865
save model
epoch = 8; train_loss = 2.405266471862793; test_loss=1.9027437114715575
save model
epoch = 9; train_loss = 2.279712173461914; test_loss=1.8413267755508422
save model
epoch = 10; train_loss = 2.2015649900436403; test_loss=1.8135796117782592
save model
epoch = 11; train_loss = 2.1712128982543946; test_loss=2.04972692489624
save model
e

epoch = 0; train_loss = 60.848320343017576; test_loss=50.258057861328126
save model
epoch = 1; train_loss = 42.37622311401367; test_loss=35.20178634643555
save model
epoch = 2; train_loss = 29.752866088867187; test_loss=25.26053939819336
save model
epoch = 3; train_loss = 20.608127990722657; test_loss=16.135546798706056
save model
epoch = 4; train_loss = 12.437546264648438; test_loss=9.100267066955567
save model
epoch = 5; train_loss = 6.563810497283936; test_loss=4.448033752441407
save model
epoch = 6; train_loss = 3.193362112045288; test_loss=2.449841241836548
save model
epoch = 7; train_loss = 2.1281088247299196; test_loss=2.0066148233413696
save model
epoch = 8; train_loss = 1.8868294801712036; test_loss=1.9697113513946534
save model
epoch = 9; train_loss = 2.0570208463668824; test_loss=2.003126893043518
epoch = 10; train_loss = 1.8820096130371093; test_loss=1.8243400573730468
save model
epoch = 11; train_loss = 2.217888611793518; test_loss=2.332285943031311
epoch = 12; train_loss 

epoch = 0; train_loss = 59.6065329284668; test_loss=47.63971649169922
save model
epoch = 1; train_loss = 41.486289154052734; test_loss=33.301532135009765
save model
epoch = 2; train_loss = 29.332801208496093; test_loss=24.36581619262695
save model
epoch = 3; train_loss = 20.10162289428711; test_loss=15.576829299926757
save model
epoch = 4; train_loss = 12.410514129638672; test_loss=8.681437339782715
save model
epoch = 5; train_loss = 6.550437118530273; test_loss=4.279465579986573
save model
epoch = 6; train_loss = 3.4313197994232176; test_loss=2.4919570159912108
save model
epoch = 7; train_loss = 2.4725820579528808; test_loss=2.1506924533843996
save model
epoch = 8; train_loss = 2.4108193264007567; test_loss=2.1661732006073
save model
epoch = 9; train_loss = 2.1804898452758787; test_loss=1.9580448055267334
save model
epoch = 10; train_loss = 2.05248148727417; test_loss=1.8910607814788818
save model
epoch = 11; train_loss = 1.9987102222442628; test_loss=1.899084405899048
save model
epoc

epoch = 49; train_loss = 0.1456159714460373; test_loss=0.1639876753091812
torus1 0.1 ----------------------------------------------------------------------------------------------------
epoch = 0; train_loss = 58.186902801513675; test_loss=45.52058959960937
save model
epoch = 1; train_loss = 40.33847360229492; test_loss=31.802446212768555
save model
epoch = 2; train_loss = 28.52710870361328; test_loss=23.3787060546875
save model
epoch = 3; train_loss = 20.637112152099608; test_loss=15.216281356811523
save model
epoch = 4; train_loss = 12.771727569580078; test_loss=9.273172073364258
save model
epoch = 5; train_loss = 7.243203647613526; test_loss=4.767129859924316
save model
epoch = 6; train_loss = 3.8184190101623536; test_loss=2.719597158432007
save model
epoch = 7; train_loss = 2.615072235107422; test_loss=2.284588499069214
save model
epoch = 8; train_loss = 2.413914192199707; test_loss=2.1121665477752685
save model
epoch = 9; train_loss = 2.1609542236328125; test_loss=2.05549796104431

epoch = 48; train_loss = 0.11826732945442199; test_loss=0.15481894493103027
save model
epoch = 49; train_loss = 0.11300837290287018; test_loss=0.15728315383195876
save model
torus1 0.5 ----------------------------------------------------------------------------------------------------
epoch = 0; train_loss = 59.98885037231445; test_loss=51.86405166625977
save model
epoch = 1; train_loss = 41.67302020263672; test_loss=36.27059799194336
save model
epoch = 2; train_loss = 29.373047897338868; test_loss=26.043401184082033
save model
epoch = 3; train_loss = 20.90655107116699; test_loss=17.194327583312987
save model
epoch = 4; train_loss = 13.346044746398926; test_loss=10.586718406677246
save model
epoch = 5; train_loss = 7.546941207885742; test_loss=5.351316251754761
save model
epoch = 6; train_loss = 3.88597598361969; test_loss=3.0585909414291383
save model
epoch = 7; train_loss = 2.729946400642395; test_loss=2.5670898389816283
save model
epoch = 8; train_loss = 2.4348588581085204; test_los

epoch = 46; train_loss = 0.13763370537757874; test_loss=0.18007359862327577
epoch = 47; train_loss = 0.13806975346803665; test_loss=0.17531623661518098
epoch = 48; train_loss = 0.13402389669418335; test_loss=0.17373304009437562
save model
epoch = 49; train_loss = 0.13062801629304885; test_loss=0.1693816041946411
save model
torus2 0.1 ----------------------------------------------------------------------------------------------------
epoch = 0; train_loss = 61.062154083251954; test_loss=51.32860229492187
save model
epoch = 1; train_loss = 42.467556030273435; test_loss=35.85173034667969
save model
epoch = 2; train_loss = 29.940530319213867; test_loss=25.656807861328126
save model
epoch = 3; train_loss = 21.59395248413086; test_loss=17.173040084838867
save model
epoch = 4; train_loss = 13.76871469116211; test_loss=10.94708335876465
save model
epoch = 5; train_loss = 8.064839023590087; test_loss=5.443332653045655
save model
epoch = 6; train_loss = 3.8620230712890624; test_loss=2.7680653858

epoch = 45; train_loss = 0.20362141239643097; test_loss=0.21606138706207276
epoch = 46; train_loss = 0.19329860317707062; test_loss=0.21297238051891326
save model
epoch = 47; train_loss = 0.18091903150081634; test_loss=0.19220730125904084
save model
epoch = 48; train_loss = 0.17405987334251405; test_loss=0.19708967506885527
save model
epoch = 49; train_loss = 0.16984657239913942; test_loss=0.19005816102027892
save model
torus2 0.5 ----------------------------------------------------------------------------------------------------
epoch = 0; train_loss = 62.523958190917966; test_loss=51.53033554077148
save model
epoch = 1; train_loss = 43.5497818145752; test_loss=36.00705307006836
save model
epoch = 2; train_loss = 30.644112884521483; test_loss=25.77215835571289
save model
epoch = 3; train_loss = 22.214628662109376; test_loss=17.212132568359376
save model
epoch = 4; train_loss = 14.210256633758545; test_loss=10.68941749572754
save model
epoch = 5; train_loss = 7.977547040939331; test_lo

epoch = 44; train_loss = 0.21436585927009583; test_loss=0.24203043699264526
save model
epoch = 45; train_loss = 0.21065219950675965; test_loss=0.23612891793251037
save model
epoch = 46; train_loss = 0.20425519359111785; test_loss=0.22381652116775513
save model
epoch = 47; train_loss = 0.1937833170890808; test_loss=0.21141630709171294
save model
epoch = 48; train_loss = 0.18874652922153473; test_loss=0.21617799162864684
save model
epoch = 49; train_loss = 0.1877949196100235; test_loss=0.204817174077034
save model


In [4]:
import pickle

# Load a single (shape, percentage) dataset for the latent-space MLP experiments below.
shape = 'torus2'
percent = '0.5'

# NOTE: pickle.load can execute arbitrary code — only load dataset files generated locally.
with open(f'./datasets_generated/dataset_geometry_percentage/dataset_{shape}_{percent}_1200.pkl', 'rb') as file:
    generated_dataset = pickle.load(file)

coordinates, a_features, u_features, a_u_pairs = generated_dataset

# scale up u features to avoid numerical errors (the ground truth in
# VisualisePredictions is rescaled by the same 10000 factor)
u_features = u_features*10000

# ensure all the same dtype
coordinates = coordinates.to(torch.float32)
a_features = a_features.to(torch.float32)
u_features = u_features.to(torch.float32)

# split by train and test
N_train = 1000
N_test = 200
N_val = 0
# Slicing past the end silently yields fewer samples; fail loudly instead, since
# VisualisePredictions indexes a_u_pairs[1000 + i] for every test sample.
assert len(coordinates) >= N_train + N_test + N_val, (
    f'dataset has {len(coordinates)} samples, need {N_train + N_test + N_val}'
)

coordinates_train = coordinates[:N_train]
a_features_train = a_features[:N_train]
u_features_train = u_features[:N_train]

# used in INR and MLP training
coordinates_test = coordinates[N_train:N_train+N_test]
a_features_test = a_features[N_train:N_train+N_test]
u_features_test = u_features[N_train:N_train+N_test]

# unseen data (empty while N_val == 0)
coordinates_val = coordinates[N_train+N_test:N_train+N_test+N_val]
a_features_val = a_features[N_train+N_test:N_train+N_test+N_val]
u_features_val = u_features[N_train+N_test:N_train+N_test+N_val]

# specify dimension of code
latent_dim = 128

# use torch Dataset format
a_train_dataset = FunctionDataset(coordinates_train, a_features_train, latent_dim)
a_test_dataset = FunctionDataset(coordinates_test, a_features_test, latent_dim)
a_val_dataset = FunctionDataset(coordinates_val, a_features_val, latent_dim)

u_train_dataset = FunctionDataset(coordinates_train, u_features_train, latent_dim)
u_test_dataset = FunctionDataset(coordinates_test, u_features_test, latent_dim)
u_val_dataset = FunctionDataset(coordinates_val, u_features_val, latent_dim)


In [None]:
import torch

# Latent codes of the input functions a, obtained with the trained input INR
# via ExtractEncoding (inner-loop encoding). Train split first, then test.
input_INR_name = f'./INR_models_trained/{shape}_input_{percent}_SIREN_INR.pt'

z_train_a, z_test_a = (
    ExtractEncoding(input_INR_name, split)
    for split in (a_train_dataset, a_test_dataset)
)



In [None]:
# Latent codes of the output functions u, obtained with the trained output INR
# via ExtractEncoding (inner-loop encoding). Train split first, then test.
output_INR_name = f'./INR_models_trained/{shape}_output_{percent}_SIREN_INR.pt'

z_train_u, z_test_u = (
    ExtractEncoding(output_INR_name, split)
    for split in (u_train_dataset, u_test_dataset)
)

In [None]:
# z-score normalise the latent codes with statistics computed on the TRAIN split
# only; the test split reuses the train statistics, so nothing leaks from it.
z_train_a_mean = z_train_a.mean(axis=0)
z_train_a_std = z_train_a.std(axis=0)
# A latent dimension that is constant over the train set has std == 0 and would
# normalise to inf/NaN. Adding the boolean mask turns exact zeros into 1 (such
# dimensions are then only centred) and leaves every other entry unchanged.
z_train_a_std = z_train_a_std + (z_train_a_std == 0)

z_train_a_normalised = (z_train_a - z_train_a_mean)/z_train_a_std
z_test_a_normalised = (z_test_a - z_train_a_mean)/z_train_a_std


z_train_u_mean = z_train_u.mean(axis=0)
z_train_u_std = z_train_u.std(axis=0)
# Same zero-std guard. This adjusted std is also what de-normalises the MLP output later.
z_train_u_std = z_train_u_std + (z_train_u_std == 0)

z_train_u_normalised = (z_train_u - z_train_u_mean)/z_train_u_std
z_test_u_normalised = (z_test_u - z_train_u_mean)/z_train_u_std

In [None]:
# Train an MLP mapping normalised input codes z_a to normalised output codes z_u.
batch_size = 32

# Fix torch's RNG so the MLP initialisation (and any torch-based shuffling inside
# train_latent_code_mlp) is reproducible across kernel restarts.
torch.manual_seed(0)

model = LatentCodeMLP(latent_dim=latent_dim, hidden_dim=64, num_layers=3)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trained_model, train_losses, test_losses = train_latent_code_mlp(
    model=model,
    dataset_a_train=z_train_a_normalised,
    dataset_u_train=z_train_u_normalised,
    dataset_a_test=z_test_a_normalised,
    dataset_u_test=z_test_u_normalised,
    latent_dim=latent_dim,
    epochs=500,
    lr=1e-4,
    batch_size=batch_size,
    device=device
)

In [None]:
import matplotlib.pyplot as plt

# Loss curves of the latent-code MLP for the run above.
fig, ax = plt.subplots()
ax.plot(train_losses, label="Train Loss")
ax.plot(test_losses, label="Test Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
ax.set_title("Training and Testing Loss")
plt.show()


In [None]:
# Second round: continue optimising the already-trained MLP with the same settings.
# train_losses / test_losses are overwritten by what this call returns
# (presumably the history of this round only — check train_latent_code_mlp).
trained_model, train_losses, test_losses = train_latent_code_mlp(
    model=trained_model,
    device=device,
    latent_dim=latent_dim,
    # data: normalised latent codes (inputs a, targets u)
    dataset_a_train=z_train_a_normalised,
    dataset_a_test=z_test_a_normalised,
    dataset_u_train=z_train_u_normalised,
    dataset_u_test=z_test_u_normalised,
    # optimisation settings
    epochs=500,
    lr=1e-4,
    batch_size=batch_size,
)

In [None]:
# Loss curves of the continued MLP training run.
fig, ax = plt.subplots()
ax.plot(train_losses, label="Train Loss")
ax.plot(test_losses, label="Test Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
ax.set_title("Training and Testing Loss")
plt.show()


# Inference: converting the transformed latent codes into output functions

In [None]:
# Map the normalised input codes through the trained MLP and undo the output
# z-score normalisation; the results approximate the output-INR codes z_u.
# Inference only: no_grad avoids building (and keeping alive) an autograd graph.
with torch.no_grad():
    train_transformed_z_a = (trained_model(z_train_a_normalised)*z_train_u_std) + z_train_u_mean  # approximation of z_u train
    test_transformed_z_a = (trained_model(z_test_a_normalised)*z_train_u_std) + z_train_u_mean  # approximation of z_u test

# these are now the codes for the INR which will make the output function

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the trained OUTPUT INR, which decodes the MLP-predicted codes into output functions.
# map_location='cpu': VisualisePredictions builds its coordinate grid on the CPU and calls
# .numpy() on the prediction, so the model must not come back on the GPU it was saved from.
# NOTE: torch.load unpickles arbitrary objects — only load checkpoints you produced.
best_model = torch.load(output_INR_name, map_location='cpu')

siren_model = best_model['inr']

In [None]:
from scipy.interpolate import griddata
import matplotlib.pyplot as plt


def VisualisePredictions(i, plots=True):
    """Decode the i-th transformed test code with the output INR and compare it to ground truth.

    Uses notebook globals: ``siren_model`` (output INR), ``test_transformed_z_a``
    (MLP-predicted output codes of the test split), ``a_u_pairs``, ``mask_generator``,
    ``shape`` and ``N_train``.

    Args:
        i: index into the test split; compared against dataset sample ``N_train + i``.
        plots: if True, show ground truth / prediction / relative-error images and
            print the mean relative error.

    Returns:
        Mean pointwise relative error over grid points where it is finite (points with a
        zero ground-truth value give NaN or inf and are excluded). NaN if none is finite.
    """
    N = 64

    # N x N regular grid on [0, 1]^2, stacked to shape (N, N, 2)
    grid_points = torch.linspace(0, 1, N)
    x, y = torch.meshgrid(grid_points, grid_points, indexing='ij')
    grid = torch.stack((x, y), dim=-1)
    mask = mask_generator(N, shape).astype(bool)

    # Evaluate on whatever device the INR lives on, then bring the result back to the
    # CPU: .numpy() only works on CPU tensors.
    model_device = next(siren_model.parameters()).device

    with torch.no_grad():
        modulations = test_transformed_z_a[i].unsqueeze(0).to(model_device)  # add batch dimension
        prediction = siren_model.modulated_forward(grid.to(model_device), modulations)
        z_predictions = (prediction.squeeze().detach().cpu() * mask).numpy()

    # ground truth u, rescaled by the same factor applied to u_features for training
    z_ground_u = 10000*a_u_pairs[N_train + i][1]

    # Division by a zero ground truth gives NaN (0/0) or inf (x/0); silence the warnings
    # and keep only finite values — ~isnan alone would let inf into the mean.
    with np.errstate(divide='ignore', invalid='ignore'):
        rel_error = abs((z_ground_u - z_predictions)/z_ground_u)
    mean_rel_error = np.mean(rel_error[np.isfinite(rel_error)])

    if plots:
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        panels = (("Ground Truth", z_ground_u), ("Prediction", z_predictions), ("rel. error", rel_error))
        for ax, (title, image) in zip(axes, panels):
            ax.set_title(title)
            im = ax.imshow(image, cmap="viridis", extent=(0, N, 0, N))
            fig.colorbar(im, ax=ax, label="Value")
        fig.tight_layout()
        plt.show()

        print(mean_rel_error)

    return mean_rel_error


In [None]:
# Mean relative error for every sample of the test split (no figures, no per-index log spam).
mean_rel_errors = [VisualisePredictions(i, plots=False) for i in range(N_test)]

In [None]:
# Summary of the per-sample mean relative errors for this (shape, percent) run:
# mean on one line, median on the next.
print(shape, percent)
print(np.mean(mean_rel_errors), np.median(mean_rel_errors), sep='\n')

In [None]:

# Distribution of the per-sample mean relative errors over the test split.
fig, ax = plt.subplots()
ax.hist(mean_rel_errors, bins=100)
ax.set(
    title=f"Per-sample mean relative error ({shape}, {percent})",
    xlabel="mean relative error",
    ylabel="number of test samples",
)
plt.show()