In [35]:
import os
import re
import torch
from tqdm import tqdm
from collections import defaultdict
from torchvision import models, transforms
from PIL import Image
import pandas as pd
import numpy as np
import xgboost as xgb

import torch.nn as nn
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel, TimesformerForVideoClassification, AutoImageProcessor
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler

device = torch.device("mps" if torch.backends.mps.is_available() else
                      "cuda" if torch.cuda.is_available() else "cpu")

## Prepare Data

In [2]:
def build_tag_vocab(tag_lists, min_freq=1):
    tag_freq = defaultdict(int)
    for tags in tag_lists:
        for tag in tags:
            tag_freq[tag.lower()] += 1

    vocab = {'[PAD]': 0, '[UNK]': 1}
    for tag, freq in tag_freq.items():
        if freq >= min_freq:
            vocab[tag] = len(vocab)

    return vocab

df_train = pd.read_csv('movie_data_train.csv')

df_train['title_overview'] = df_train['original_title'] + ': ' + df_train['overview']

df_train = pd.DataFrame({
    'title_overview': df_train['title_overview'],
    'tags': df_train['tags'].fillna(''),
    'revenue': df_train['revenue']
})

df_test = pd.read_csv('movie_data_test.csv')

df_test['title_overview'] = df_test['original_title'] + ': ' + df_test['overview']

df_test = pd.DataFrame({
    'title_overview': df_test['title_overview'],
    'tags': df_test['tags'].fillna(''),
    'revenue': df_test['revenue']
})

df_train['revenue'] = np.log1p(df_train['revenue'])
df_test['revenue'] = np.log1p(df_test['revenue'])

df_train['tags'] = df_train['tags'].apply(lambda x: [tag.strip().lower() for tag in x.split(',') if tag.strip()])
df_test['tags'] = df_test['tags'].apply(lambda x: [tag.strip().lower() for tag in x.split(',') if tag.strip()])

train_texts = df_train['title_overview'].tolist()
train_tags = df_train['tags'].tolist()
train_targets = df_train['revenue'].tolist()

test_texts = df_test['title_overview'].tolist()
test_tags = df_test['tags'].tolist()
test_targets = df_test['revenue'].tolist()

tag_vocab = build_tag_vocab(train_tags + test_tags)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

## Prepare Title/Overview + Tag 2 Tower BERT Model

In [3]:
class TagCNNEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=300, num_filters=128, kernel_sizes=(2, 3, 4), dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, num_filters, k) for k in kernel_sizes
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        embedded = self.embedding(x).transpose(1, 2)
        conv_outs = [torch.relu(conv(embedded)).max(dim=2)[0] for conv in self.convs]
        out = torch.cat(conv_outs, dim=1)
        return self.dropout(out)

class BERTWithTagCNNRegressor(nn.Module):
    def __init__(self, tag_vocab_size, dropout=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.tag_encoder = TagCNNEncoder(tag_vocab_size)
        self.dropout = nn.Dropout(dropout)

        self.regressor = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size + 384, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, text_input_ids, text_attention_mask, tag_token_ids):

        bert_output = self.bert(input_ids=text_input_ids, attention_mask=text_attention_mask)
        text_cls = bert_output.pooler_output  

        tag_feat = self.tag_encoder(tag_token_ids)

        fused = torch.cat([text_cls, tag_feat], dim=1)
        return self.regressor(self.dropout(fused))
    
class MovieDatasetWithTags(nn.Module):
    def __init__(self, texts, tags, targets, tokenizer, tag_vocab, max_text_len=256, max_tag_len=20):
        self.texts = texts
        self.tags = tags
        self.targets = targets
        self.tokenizer = tokenizer
        self.tag_vocab = tag_vocab
        self.max_text_len = max_text_len
        self.max_tag_len = max_tag_len

    def __len__(self):
        return len(self.texts)

    def encode_tags(self, tag_list):
        tag_ids = [self.tag_vocab.get(tag.lower(), self.tag_vocab['[UNK]']) for tag in tag_list]
        tag_ids = tag_ids[:self.max_tag_len]
        tag_ids += [self.tag_vocab['[PAD]']] * (self.max_tag_len - len(tag_ids))
        return torch.tensor(tag_ids, dtype=torch.long)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tags = self.tags[idx]
        target = torch.tensor(self.targets[idx], dtype=torch.float)

        text_enc = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_text_len,
            return_tensors='pt'
        )

        tag_tensor = self.encode_tags(tags)

        return {
            'input_ids': text_enc['input_ids'].squeeze(0),
            'attention_mask': text_enc['attention_mask'].squeeze(0),
            'tags': tag_tensor,
            'target': target
        }

train_texts_split, val_texts, train_tags_split, val_tags, train_targets_split, val_targets = train_test_split(
    train_texts, train_tags, train_targets, test_size=0.2, random_state=42
)

train_dataset = MovieDatasetWithTags(train_texts_split, train_tags_split, train_targets_split, tokenizer, tag_vocab)
val_dataset = MovieDatasetWithTags(val_texts, val_tags, val_targets, tokenizer, tag_vocab)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

test_dataset = MovieDatasetWithTags(test_texts, test_tags, test_targets, tokenizer, tag_vocab)
test_loader = DataLoader(test_dataset, batch_size=16)

## Get Predictions of 2 Tower BERT Model

In [4]:
model = BERTWithTagCNNRegressor(tag_vocab_size=len(tag_vocab)).to(device)

criterion = nn.MSELoss()
optimizer = AdamW(model.parameters(), lr=3e-5)
checkpoint = torch.load('models/title_overview_two_tower_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

title_train_predictions = []
train_actuals = []
title_val_predictions = []
val_actuals = []
title_test_predictions = []
test_actuals = []

with torch.no_grad():
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tags = batch['tags'].to(device)
        targets = batch['target'].cpu().numpy()
        
        outputs = model(input_ids, attention_mask, tags).squeeze().cpu().numpy()
        
        title_train_predictions.extend(outputs)
        train_actuals.extend(targets)

    for batch in val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tags = batch['tags'].to(device)
        targets = batch['target'].cpu().numpy()
        
        outputs = model(input_ids, attention_mask, tags).squeeze().cpu().numpy()
        
        title_val_predictions.extend(outputs)
        val_actuals.extend(targets)

    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tags = batch['tags'].to(device)
        targets = batch['target'].cpu().numpy()

        outputs = model(input_ids, attention_mask, tags).squeeze().cpu().numpy()

        title_test_predictions.extend(outputs)
        test_actuals.extend(targets)
        
        

title_train_preds = np.expm1(title_train_predictions)
train_actuals = np.expm1(train_actuals)
title_val_preds = np.expm1(title_val_predictions)
val_actuals = np.expm1(val_actuals)
title_test_preds = np.expm1(title_test_predictions)
test_actuals = np.expm1(test_actuals)

## Setting Up Visual Ensemble Model

In [5]:
class MultiImageDataset(Dataset):
    def __init__(self, df, poster_dir, backdrop_dir, thumbnail_dir, transform):
        self.df = df
        self.poster_dir = poster_dir
        self.backdrop_dir = backdrop_dir
        self.thumbnail_dir = thumbnail_dir
        self.transform = transform
        self.valid_ids = []

        for idx, row in df.iterrows():
            movie_id = str(int(row['id']))
            if all(os.path.exists(os.path.join(d, f"{movie_id}.jpg")) for d in [poster_dir, backdrop_dir, thumbnail_dir]):
                self.valid_ids.append(idx)
            else:
                self.valid_ids.append(-1)                

    def __len__(self):
        return len(self.valid_ids)

    def __getitem__(self, idx):
        
        if self.valid_ids[idx] == -1:
            return {
                "poster": torch.zeros(3, 224, 224),
                "backdrop": torch.zeros(3, 224, 224),
                "thumbnail": torch.zeros(3, 224, 224),
                "revenue": torch.tensor(0, dtype=torch.float)
            }
        
        df_idx = self.valid_ids[idx]
        row = self.df.iloc[df_idx]
        movie_id = str(int(row['id']))
        revenue = np.log1p(row['revenue'])

        def load_image(directory):
            image = Image.open(os.path.join(directory, f"{movie_id}.jpg")).convert("RGB")
            return self.transform(image)

        return {
            "poster": load_image(self.poster_dir),
            "backdrop": load_image(self.backdrop_dir),
            "thumbnail": load_image(self.thumbnail_dir),
            "revenue": torch.tensor(revenue, dtype=torch.float)
        }

def get_resnet_backbone():
    resnet = models.resnet50(pretrained=True)
    for param in resnet.parameters():
        param.requires_grad = False
    for param in resnet.layer4.parameters():
        param.requires_grad = True
    for param in resnet.avgpool.parameters():
        param.requires_grad = True
    return nn.Sequential(*list(resnet.children())[:-1])

class FineTunedEnsemble(nn.Module):
    def __init__(self):
        super().__init__()
        self.poster_net = get_resnet_backbone()
        self.backdrop_net = get_resnet_backbone()
        self.thumbnail_net = get_resnet_backbone()

        self.mlp = nn.Sequential(
            nn.Linear(2048*3, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, poster, backdrop, thumbnail):
        p = self.poster_net(poster)
        b = self.backdrop_net(backdrop)
        t = self.thumbnail_net(thumbnail)

        x = torch.cat([p.view(p.size(0), -1), b.view(b.size(0), -1), t.view(t.size(0), -1)], dim=1)
        return self.mlp(x)

## Preparing Data For Visuals Ensemble Model

In [6]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

df_train = pd.read_csv("movie_data_train.csv")
df_test = pd.read_csv("movie_data_test.csv")

train_dataset = MultiImageDataset(df_train, "poster_dataset", "backdrop_dataset", "thumbnail_dataset", transform)
test_dataset = MultiImageDataset(df_test, "poster_dataset", "backdrop_dataset", "thumbnail_dataset", transform)

train_idx, val_idx = train_test_split(list(range(len(train_dataset))), test_size=0.2, random_state=42)
train_subset = torch.utils.data.Subset(train_dataset, train_idx)
val_subset = torch.utils.data.Subset(train_dataset, val_idx)

train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=16)
test_loader = DataLoader(test_dataset, batch_size=16)

## Get Predictions from Visuals Ensemble Model

In [7]:
model = FineTunedEnsemble().to(device)
checkpoint = torch.load('models/best_ensemble_model.pt', map_location=device)
model.load_state_dict(checkpoint)
model.eval()
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

visuals_train_preds = []
visuals_val_preds = []
visuals_test_preds = []

with torch.no_grad():
    for batch in tqdm(train_loader, desc="Loading Training Predictions"):
        p = batch["poster"].to(device)
        b = batch["backdrop"].to(device)
        t = batch["thumbnail"].to(device)
        y = batch["revenue"].to(device)
        y_hat = model(p, b, t).squeeze()
        visuals_train_preds.extend(y_hat.cpu().view(-1).tolist())

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Loading Validation Predictions"):
        p = batch["poster"].to(device)
        b = batch["backdrop"].to(device)
        t = batch["thumbnail"].to(device)
        y = batch["revenue"].to(device)
        y_hat = model(p, b, t).squeeze()
        visuals_val_preds.extend(y_hat.cpu().view(-1).tolist())

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Loading Test Predictions"):
        p = batch["poster"].to(device)
        b = batch["backdrop"].to(device)
        t = batch["thumbnail"].to(device)
        y = batch["revenue"].to(device)
        y_hat = model(p, b, t).squeeze()
        visuals_test_preds.extend(y_hat.cpu().view(-1).tolist())

visuals_train_preds = np.expm1(visuals_train_preds)
visuals_val_preds = np.expm1(visuals_val_preds)
visuals_test_preds = np.expm1(visuals_test_preds)

Loading Training Predictions: 100%|██████████| 175/175 [00:40<00:00,  4.30it/s]
Loading Validation Predictions: 100%|██████████| 44/44 [00:10<00:00,  4.36it/s]
Loading Test Predictions: 100%|██████████| 55/55 [00:12<00:00,  4.44it/s]


## Setting Up Trailer Model

In [8]:
def numerical_sort_key(filename):
    return [int(text) if text.isdigit() else text.lower()
            for text in re.split(r'(\d+)', filename)]

class MovieKeyframeDataset(Dataset):
    def __init__(self, dataframe, frame_dir, image_processor, num_frames=8):
        self.dataframe = dataframe
        self.frame_dir = frame_dir
        self.image_processor = image_processor
        self.num_frames = num_frames

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        trailer_id = row['trailer']
        label = torch.tensor(row['log_revenue'], dtype=torch.float32)

        frame_folder = os.path.join(self.frame_dir, str(trailer_id))

        if not os.path.exists(frame_folder):

            dummy_frames = torch.zeros((self.num_frames, 3, 224, 224))
            return {
                "pixel_values": dummy_frames,
                "labels": label
            }

        frame_files = sorted([
            f for f in os.listdir(frame_folder) if f.endswith(".jpg")
        ], key=numerical_sort_key)

        selected_frames = frame_files[3:self.num_frames+3]

        frames = []
        for fname in selected_frames:
            img_path = os.path.join(frame_folder, fname)
            img = Image.open(img_path).convert("RGB")
            frames.append(np.array(img))

        pixel_values = self.image_processor(frames, return_tensors="pt")["pixel_values"][0]

        return {
            "pixel_values": pixel_values,
            "labels": label             
        }
    


class TimeSformer(nn.Module):
    def __init__(self, model_name="facebook/timesformer-base-finetuned-k400"):
        super().__init__()
        self.backbone = TimesformerForVideoClassification.from_pretrained(model_name)

        hidden_size = self.backbone.config.hidden_size 

        self.backbone.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, pixel_values): 
        outputs = self.backbone(pixel_values)
        return outputs.logits.view(-1) 

## Loading Data for Trailer Model

In [9]:
processor = AutoImageProcessor.from_pretrained("facebook/timesformer-base-finetuned-k400")


df_train = pd.read_csv('movie_data_train.csv')
df_train['log_revenue'] = np.log1p(df_train['revenue'])

train_dataset, val_dataset = train_test_split(df_train, test_size=0.2, random_state=42)

train_dataset = MovieKeyframeDataset(train_dataset, "frames", processor, num_frames=8)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)

val_dataset = MovieKeyframeDataset(val_dataset, "frames", processor, num_frames=8)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

df_test = pd.read_csv('movie_data_test.csv')
df_test['log_revenue'] = np.log1p(df_test['revenue'])
test_dataset = MovieKeyframeDataset(df_test, "frames", processor, num_frames=8)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


## Get Predictions from Trailer Model

In [10]:
model = TimeSformer().to(device) 
checkpoint = torch.load('models/best_trailer_model.pt', map_location=device)
model.load_state_dict(checkpoint)
model.eval()

trailer_train_preds = []
trailer_val_preds = []
trailer_test_preds = []

with torch.no_grad():
    for batch in tqdm(train_loader, desc="Loading Training Predictions"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values)
        loss = criterion(outputs, labels)
        
        trailer_train_preds.extend(outputs.cpu().numpy())

    for batch in tqdm(val_loader, desc="Loading Validation Predictions"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values)
        loss = criterion(outputs, labels)
        
        trailer_val_preds.extend(outputs.cpu().numpy())

    for batch in tqdm(test_loader, desc="Loading Test Predictions"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values)
        loss = criterion(outputs, labels)
        
        trailer_test_preds.extend(outputs.cpu().numpy())


trailer_train_preds = np.expm1(trailer_train_preds)
trailer_val_preds = np.expm1(trailer_val_preds)
trailer_test_preds = np.expm1(trailer_test_preds)


Loading Training Predictions: 100%|██████████| 2786/2786 [05:11<00:00,  8.93it/s]
Loading Validation Predictions: 100%|██████████| 697/697 [01:20<00:00,  8.62it/s]
Loading Test Predictions: 100%|██████████| 871/871 [01:39<00:00,  8.79it/s]


## Prepare the Dataset for Generic Features Model

In [11]:
df = pd.read_csv('movie_data_train.csv')

numerical_features = ['budget', 'runtime', 'viewCount', 'likeCount', 'favoriteCount', 'commentCount']

df['release_month'] = pd.to_datetime(df['release_date']).dt.month

df['genres_list'] = df['genres'].str.split(',')

X_numeric = df[numerical_features + ['release_month']]
y = df['revenue']

genres_exploded = df['genres_list'].explode()
unique_genres = genres_exploded.dropna().unique()

for genre in unique_genres:
    df[f'genre_{genre.strip()}'] = df['genres_list'].apply(
        lambda x: 1 if x is not None and genre in [g.strip() for g in x] else 0
    )

genre_columns = [col for col in df.columns if col.startswith('genre_')]

X_combined = pd.concat([X_numeric, df[genre_columns]], axis=1)

X_train, X_val, y_train, y_val = train_test_split(
    X_combined, y, test_size=0.2, random_state=42
)

imputer = SimpleImputer(strategy='constant', fill_value=0)

X_train = imputer.fit_transform(X_train)
X_val = imputer.transform(X_val)

test_df = pd.read_csv('movie_data_test.csv')

test_df['release_month'] = pd.to_datetime(test_df['release_date']).dt.month
test_df['genres_list'] = test_df['genres'].str.split(',')

X_test_numeric = test_df[numerical_features + ['release_month']]

for genre in unique_genres:
    test_df[f'genre_{genre.strip()}'] = test_df['genres_list'].apply(
        lambda x: 1 if x is not None and genre in [g.strip() for g in x] else 0
    )

X_test = pd.concat([X_test_numeric, test_df[genre_columns]], axis=1)

print(f"Test data shape: {X_test.shape}")

X_test = SimpleImputer(strategy='constant', fill_value=0).fit_transform(X_test)

y_test = test_df['revenue']

Test data shape: (871, 26)


## Get Predictions from Generic Features Model

In [12]:
model = xgb.XGBRegressor(
            n_estimators=2000,
            learning_rate=0.015,
            max_leaves=10,
            subsample=0.5,
            colsample_bytree=0.6,
            reg_alpha=0.1,
            reg_lambda=0.1,
            random_state=42,
            min_child_weight=40,
            tree_method="hist", 
            verbosity=0
        )

model.fit(
    X_train, y_train,
    eval_set=[(X_val, y_val)],
    verbose=False
)

generic_train_preds = model.predict(X_train)
generic_val_preds = model.predict(X_val)
generic_test_preds = model.predict(X_test)

## Preparing Data for Late Fusion Model

In [22]:
title_train_preds = title_train_preds.reshape(-1, 1)
title_val_preds = title_val_preds.reshape(-1, 1)
title_test_preds = title_test_preds.reshape(-1, 1)

visuals_train_preds = visuals_train_preds.reshape(-1, 1)
visuals_val_preds = visuals_val_preds.reshape(-1, 1)
visuals_test_preds = visuals_test_preds.reshape(-1, 1)

trailer_train_preds = trailer_train_preds.reshape(-1, 1)
trailer_val_preds = trailer_val_preds.reshape(-1, 1)
trailer_test_preds = trailer_test_preds.reshape(-1, 1)

generic_train_preds = generic_train_preds.reshape(-1, 1)
generic_val_preds = generic_val_preds.reshape(-1, 1)
generic_test_preds = generic_test_preds.reshape(-1, 1)

train_actuals = train_actuals.reshape(-1,1)
val_actuals = val_actuals.reshape(-1,1)
test_actuals = test_actuals.reshape(-1,1)

train_actuals = torch.from_numpy(train_actuals).to(device)
val_actuals = torch.from_numpy(val_actuals).to(device)
test_actuals = torch.from_numpy(test_actuals).to(device)

train_preds = np.concatenate([title_train_preds, visuals_train_preds, trailer_train_preds, generic_train_preds], axis=1)
val_preds = np.concatenate([title_val_preds, visuals_val_preds, trailer_val_preds, generic_val_preds], axis=1)
test_preds = np.concatenate([title_test_preds, visuals_test_preds, trailer_test_preds, generic_test_preds], axis=1)

train_preds = torch.from_numpy(train_preds.astype(np.float32)).to(device)
val_preds = torch.from_numpy(val_preds.astype(np.float32)).to(device)
test_preds = torch.from_numpy(test_preds.astype(np.float32)).to(device)



TypeError: expected np.ndarray (got Tensor)

In [65]:
class GatingNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dims=[32, 16]):
        super().__init__()
        layers = []
        dims = [input_dim] + hidden_dims

        for i in range(len(dims) - 1):
            layers.extend([
                nn.Linear(dims[i], dims[i + 1]),
                nn.ReLU()
            ])
        layers.append(nn.Linear(dims[-1], input_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        weights = F.softmax(logits, dim=1)
        return weights


class GatingFusionEnsemble:
    def __init__(self, lr=1e-3, epochs=100, logging=True):
        self.gating_net = None
        self.logging = logging
        self.epochs = epochs
        self.lr = lr

    def fit(self, X_train, y_train, X_val=None, y_val=None):
        input_dim = X_train.shape[1]
        self.gating_net = GatingNetwork(input_dim).to(X_train.device)
        optimizer = torch.optim.Adam(self.gating_net.parameters(), lr=self.lr)
        criterion = nn.MSELoss()

        for epoch in range(self.epochs):
            self.gating_net.train()
            optimizer.zero_grad()
            weights = self.gating_net(X_train)
            weighted_preds = torch.sum(X_train * weights, dim=1, keepdim=True)
            loss = criterion(weighted_preds, y_train)
            loss.backward()
            optimizer.step()

            if X_val is not None and y_val is not None:
                self.gating_net.eval()
                with torch.no_grad():
                    val_weights = self.gating_net(X_val)
                    val_preds = torch.sum(X_val * val_weights, dim=1, keepdim=True)
                    val_loss = criterion(val_preds, y_val)

                    val_r2 = r2_score(y_val.cpu().numpy(), val_preds.cpu().numpy())

                    if self.logging:
                        print(f"Epoch {epoch} | Train Loss: {loss.item():.4f} | Val Loss: {val_loss.item():.4f} | Val R²: {val_r2:.4f}")

        return self

    def predict(self, X):
        self.gating_net.eval()
        with torch.no_grad():
            weights = self.gating_net(X)
            return torch.sum(X * weights, dim=1, keepdim=True)

In [73]:
X_scaler = StandardScaler()
X_train = torch.tensor(X_scaler.fit_transform(train_preds.cpu().numpy()), dtype=torch.float32).to(device)
X_val = torch.tensor(X_scaler.transform(val_preds.cpu().numpy()), dtype=torch.float32).to(device)

y_scaler = StandardScaler()
y_train = torch.tensor(y_scaler.fit_transform(train_actuals.cpu().numpy()), dtype=torch.float32).to(device)
y_val = torch.tensor(y_scaler.transform(val_actuals.cpu().numpy()), dtype=torch.float32).to(device)

X_test = torch.tensor(X_scaler.transform(test_preds.cpu().numpy()), dtype=torch.float32).to(device)
y_test = torch.tensor(y_scaler.transform(test_actuals.cpu().numpy()), dtype=torch.float32).to(device)

gated_ensemble = GatingFusionEnsemble(lr = 1e-5, epochs=10)
gated_ensemble.fit(X_train, y_train, X_val, y_val)

with torch.no_grad():
    val_pred = gated_ensemble.predict(X_val)
    val_pred_np = val_pred.cpu().numpy()
    val_actuals_np = y_val.cpu().numpy()
    val_r2 = r2_score(val_actuals_np, val_pred_np)

    test_pred = gated_ensemble.predict(X_test)
    test_pred_np = test_pred.cpu().numpy()
    test_actuals_np = y_test.cpu().numpy()
    test_r2 = r2_score(test_actuals_np, test_pred_np)

    print(f"Validation R²: {val_r2:.4f}")
    print(f"Test R²: {test_r2:.4f}")



Epoch 0 | Train Loss: 0.9471 | Val Loss: 0.7622 | Val R²: 0.4538
Epoch 1 | Train Loss: 0.9470 | Val Loss: 0.7623 | Val R²: 0.4538
Epoch 2 | Train Loss: 0.9469 | Val Loss: 0.7623 | Val R²: 0.4538
Epoch 3 | Train Loss: 0.9468 | Val Loss: 0.7623 | Val R²: 0.4538
Epoch 4 | Train Loss: 0.9467 | Val Loss: 0.7623 | Val R²: 0.4538
Epoch 5 | Train Loss: 0.9467 | Val Loss: 0.7623 | Val R²: 0.4538
Epoch 6 | Train Loss: 0.9466 | Val Loss: 0.7623 | Val R²: 0.4538
Epoch 7 | Train Loss: 0.9465 | Val Loss: 0.7624 | Val R²: 0.4538
Epoch 8 | Train Loss: 0.9464 | Val Loss: 0.7624 | Val R²: 0.4538
Epoch 9 | Train Loss: 0.9463 | Val Loss: 0.7624 | Val R²: 0.4537
Validation R²: 0.4537
Test R²: 0.4352
