In [None]:
import os
import bisect
import time
import random
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight 
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_recall_fscore_support
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from prince import FAMD, PCA, MCA

# Preprocessing

In [None]:
# MICE + Feature selection + MCA (Individual)
# Loads every source table, drops the columns that are not used as features,
# expands selected categorical columns into MCA components and finally
# standardizes all non-key columns of each table.
# NOTE(review): the "MICED_" file prefix presumably marks MICE-imputed inputs — confirm.
# One MCA object is shared; every fit_transform call below fits it again on the
# single column it is given, so each categorical column gets its own 16 components.
mca = MCA(n_components=16, random_state=0)
ccba = pd.read_csv('data/concat_ccba.csv').drop(columns=['clamt', 'csamt', 'cucah'])
cdtx = pd.read_csv('data/concat_cdtx0001.csv').drop(columns=['country', 'cur_type'])
# occupation_code is cast to str and replaced by its MCA components.
custinfo = pd.read_csv('data/MICED_concat_custinfo.csv').astype({'occupation_code': str})
custinfo = pd.concat([custinfo.drop(columns=['occupation_code']), mca.fit_transform(pd.DataFrame(custinfo['occupation_code']))], axis=1)
# The replace() runs on the whole frame: any cell equal to 'CR'/'DB' becomes 0/1.
dp = pd.read_csv('data/MICED_concat_dp.csv').astype({'tx_type': str, 'info_asset_code': str}).replace({'CR': 0, 'DB': 1})
# tx_type and info_asset_code are replaced by MCA components; fiscTxId and txbranch are dropped.
# NOTE(review): both MCA outputs presumably use the same component column labels,
# so dp may temporarily hold duplicate column names — verify against prince's output.
dp = pd.concat([dp.drop(columns=['tx_type', 'info_asset_code', 'fiscTxId', 'txbranch']), mca.fit_transform(pd.DataFrame(dp['tx_type'])), mca.fit_transform(pd.DataFrame(dp['info_asset_code']))], axis=1)
# Scale the amount by the exchange rate; exchg_rate itself is excluded from the
# scaled features further down.
dp['tx_amt'] *= dp['exchg_rate']
remit = pd.read_csv('data/concat_remit1.csv').drop(columns=['trans_no'])
public_alert = pd.read_csv('data/public_x_alert_date.csv')
private_alert = pd.read_csv('data/private_x_alert_date.csv')
# Previous setup (training alerts/answers only), kept for reference:
# train_alert = pd.read_csv('data/train_x_alert_date.csv')
# train_answer = pd.read_csv('data/train_y_answer.csv')
# Current setup: the public alerts and the public answer file are appended to
# the training alerts/answers.
train_alert = pd.concat([pd.read_csv('data/train_x_alert_date.csv'), public_alert], ignore_index=True)
train_answer = pd.concat([pd.read_csv('data/train_y_answer.csv'), pd.read_csv('data/24_ESun_public_y_answer.csv')], ignore_index=True)
sample_submission = pd.read_csv('sample_submission.csv')

# Standardize the feature columns of every table. The scaler is re-fitted per
# table. Key/date columns are kept unscaled in front; the scaled features are
# rebuilt from a bare array, so their column labels become 0..k-1.
# pd.concat(axis=1) aligns on the index, so this relies on each table still
# carrying the default positional index from read_csv.
scaler = StandardScaler()
cdtx = pd.concat([cdtx[['cust_id', 'date']], pd.DataFrame(scaler.fit_transform(cdtx.drop(columns=['cust_id', 'date'])))], axis=1)
custinfo = pd.concat([custinfo[['cust_id', 'alert_key']], pd.DataFrame(scaler.fit_transform(custinfo.drop(columns=['cust_id', 'alert_key'])))], axis=1)
dp = pd.concat([dp[['cust_id', 'tx_date']], pd.DataFrame(scaler.fit_transform(dp.drop(columns=['cust_id', 'tx_date', 'exchg_rate'])))], axis=1)
remit = pd.concat([remit[['cust_id', 'trans_date']], pd.DataFrame(scaler.fit_transform(remit.drop(columns=['cust_id', 'trans_date'])))], axis=1)
ccba = pd.concat([ccba[['cust_id', 'byymm']], pd.DataFrame(scaler.fit_transform(ccba.drop(columns=['cust_id', 'byymm'])))], axis=1)

In [None]:
# Build one training sample per alert: the static customer-info row plus the
# most recent rows of each behaviour table, zero-padded to a common length per
# table, then cache everything to disk.
all_data, all_label = [], []
# window: how far back (in the tables' date units) rows are kept before the alert date.
# max_len: maximum number of most-recent rows kept per sequence.
window, max_len = 30, 64
n = len(train_alert)

# name -> (table, date column, look-back). A look-back of None means "no lower
# bound": every row up to and including the alert date is eligible.
seq_sources = {
    'ccba': (ccba, 'byymm', None),
    'cdtx': (cdtx, 'date', window),
    'dp': (dp, 'tx_date', window),
    'remit': (remit, 'trans_date', window),
}
# Split every table by customer once. Scanning the full tables for each alert
# (as a per-alert query would) costs O(#alerts * #rows).
rows_by_cust = {
    name: {cid: grp for cid, grp in table.groupby('cust_id', sort=False)}
    for name, (table, _, _) in seq_sources.items()
}
# Zero-row frames with the right columns, used for customers with no rows at all.
no_rows = {name: table.iloc[:0] for name, (table, _, _) in seq_sources.items()}

longest = [0, 0, 0, 0]  # running max sequence length per table, as plain ints
for i, row in train_alert.iterrows():
    alert_key, date = row['alert_key'], row['date']
    info = custinfo[custinfo['alert_key'] == alert_key]
    cust_id = info['cust_id'].iloc[0]

    sample = {'info': torch.tensor(info.drop(columns=['cust_id', 'alert_key']).to_numpy('float32'))}
    for name, (table, date_col, lookback) in seq_sources.items():
        # Customers are matched by their string form.
        rows = rows_by_cust[name].get(str(cust_id))
        if rows is None:
            rows = no_rows[name]
        if lookback is None:
            in_range = rows[date_col] <= date
        else:
            in_range = rows[date_col].between(date - lookback, date)  # inclusive on both ends
        recent = rows[in_range].sort_values(date_col).drop(columns=['cust_id', date_col])
        sample[name] = torch.tensor(recent.to_numpy('float32')[-max_len:])
    sample['length'] = torch.tensor([len(sample[name]) for name in seq_sources])

    all_data.append(sample)
    all_label.append(torch.tensor(train_answer[train_answer['alert_key'] == alert_key]['sar_flag'].iloc[0]))
    longest = [max(best, cur) for best, cur in zip(longest, sample['length'].tolist())]

    if i % 1000 == 0:
        print(i, n, *longest)

l1, l2, l3, l4 = longest  # kept for any later cell that reads the individual maxima

input_size = [all_data[-1]['info'].shape[-1]] + [all_data[-1][name].shape[-1] for name in seq_sources]
print(input_size)

# Right-pad every sequence with zero rows up to the longest one seen for its
# table; 'length' keeps the true (unpadded) lengths.
for i, sample in enumerate(all_data):
    padded = {'info': sample['info']}
    for pos, name in enumerate(seq_sources):
        missing = longest[pos] - int(sample['length'][pos])
        padded[name] = torch.cat((sample[name], torch.zeros(missing, input_size[pos + 1])))
    padded['length'] = sample['length']
    all_data[i] = padded

    if i % 1000 == 0:
        print(i, n)

# Create the cache directory up front so the long loop above is not wasted on a
# missing-folder error.
os.makedirs('data_pt', exist_ok=True)
torch.save(all_data, f'data_pt/all_data_{window}_{max_len}.pt')
torch.save(all_label, f'data_pt/all_label_{window}_{max_len}.pt')

In [None]:
# Build one inference sample per public/private alert, using the same feature
# layout as the training samples plus the alert key, then cache to disk.
# Relies on `window` and `max_len` from the training-sample cell.
test_data = []
n = len(public_alert) + len(private_alert)

# name -> (table, date column, look-back). A look-back of None means "no lower
# bound": every row up to and including the alert date is eligible.
test_sources = {
    'ccba': (ccba, 'byymm', None),
    'cdtx': (cdtx, 'date', window),
    'dp': (dp, 'tx_date', window),
    'remit': (remit, 'trans_date', window),
}
# Split every table by customer once instead of scanning the full tables for
# each alert.
test_rows_by_cust = {
    name: {cid: grp for cid, grp in table.groupby('cust_id', sort=False)}
    for name, (table, _, _) in test_sources.items()
}
# Zero-row frames with the right columns, used for customers with no rows at all.
test_no_rows = {name: table.iloc[:0] for name, (table, _, _) in test_sources.items()}

test_longest = [0, 0, 0, 0]  # running max sequence length per table, as plain ints
for i, row in pd.concat([public_alert, private_alert], ignore_index=True).iterrows():
    alert_key, date = row['alert_key'], row['date']
    info = custinfo[custinfo['alert_key'] == alert_key]
    cust_id = info['cust_id'].iloc[0]

    sample = {
        'alert_key': alert_key,
        'info': torch.tensor(info.drop(columns=['cust_id', 'alert_key']).to_numpy('float32')),
    }
    for name, (table, date_col, lookback) in test_sources.items():
        # Customers are matched by their string form.
        rows = test_rows_by_cust[name].get(str(cust_id))
        if rows is None:
            rows = test_no_rows[name]
        if lookback is None:
            in_range = rows[date_col] <= date
        else:
            in_range = rows[date_col].between(date - lookback, date)  # inclusive on both ends
        recent = rows[in_range].sort_values(date_col).drop(columns=['cust_id', date_col])
        sample[name] = torch.tensor(recent.to_numpy('float32')[-max_len:])
    sample['length'] = torch.tensor([len(sample[name]) for name in test_sources])

    test_data.append(sample)
    test_longest = [max(best, cur) for best, cur in zip(test_longest, sample['length'].tolist())]

    if i % 1000 == 0:
        print(i, n, *test_longest)

l1, l2, l3, l4 = test_longest  # kept for any later cell that reads the individual maxima

input_size = [test_data[-1]['info'].shape[-1]] + [test_data[-1][name].shape[-1] for name in test_sources]
print(input_size)

# Right-pad every sequence with zero rows up to the longest one seen for its
# table; 'length' keeps the true (unpadded) lengths.
for i, sample in enumerate(test_data):
    padded = {'alert_key': sample['alert_key'], 'info': sample['info']}
    for pos, name in enumerate(test_sources):
        missing = test_longest[pos] - int(sample['length'][pos])
        padded[name] = torch.cat((sample[name], torch.zeros(missing, input_size[pos + 1])))
    padded['length'] = sample['length']
    test_data[i] = padded

    if i % 1000 == 0:
        print(i, n)

# Create the cache directory up front so the long loop above is not wasted on a
# missing-folder error.
os.makedirs('data_pt', exist_ok=True)
torch.save(test_data, f'data_pt/test_data_{window}_{max_len}.pt')

# Training

In [None]:
class SARDataset(Dataset):
    """Map-style dataset that pairs each sample with the label at the same position."""

    def __init__(self, data, label):
        # Both containers are stored as given; they are indexed with the same index.
        self.data, self.label = data, label

    def __len__(self):
        """Number of samples (taken from `data`)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the `(sample, label)` tuple stored at `index`."""
        sample = self.data[index]
        target = self.label[index]
        return sample, target

class Attention(nn.Module):
    """Parameter-free dot-product attention over a sequence of outputs."""

    def __init__(self):
        super().__init__()

    def forward(self, out, h_out):
        """Attend over `out` using `h_out` as the query.

        The batched products below require `out` to be 3-D (batch, steps,
        features) and `h_out` to be 2-D (batch, features).

        Returns:
            (context_vector, attention_weights): the weighted sum of `out`
            over the step axis, and the softmax weights over the steps.
        """
        query = h_out.unsqueeze(-1)
        # One dot-product score per step.
        step_scores = (out @ query).squeeze(-1)
        attention_weights = F.softmax(step_scores, dim=1)
        # Weighted sum of the step outputs.
        weighted = out.transpose(1, 2) @ attention_weights.unsqueeze(-1)
        context_vector = weighted.squeeze(-1)
        return context_vector, attention_weights
        
class SARModel(torch.nn.Module):
    """Classifier combining one static feature vector with four GRU-encoded sequences.

    `input_size` and `hidden_size` are indexed 0..4: index 0 belongs to the
    static 'info' branch, indices 1..4 to the 'ccba', 'cdtx', 'dp' and 'remit'
    sequence branches (in that order).
    """

    # Batch keys of the sequence branches; position k feeds `gru{k+1}` / `attn{k+1}`.
    _SEQUENCE_KEYS = ('ccba', 'cdtx', 'dp', 'remit')

    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional, attention, num_class):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.attention = attention
        self.num_class = num_class

        # Static branch: three Linear -> BatchNorm -> ReLU stages, with dropout
        # after the first two.
        width = hidden_size[0]
        self.fc1 = nn.Sequential(
            nn.Linear(input_size[0], width), nn.BatchNorm1d(width), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(width, width), nn.BatchNorm1d(width), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(width, width), nn.BatchNorm1d(width), nn.ReLU(),
        )

        # Sequence branches: one GRU per table, all sharing the same options.
        gru_options = dict(num_layers=num_layers, dropout=dropout, batch_first=True, bidirectional=bidirectional)
        self.gru1 = nn.GRU(input_size=input_size[1], hidden_size=hidden_size[1], **gru_options)
        self.gru2 = nn.GRU(input_size=input_size[2], hidden_size=hidden_size[2], **gru_options)
        self.gru3 = nn.GRU(input_size=input_size[3], hidden_size=hidden_size[3], **gru_options)
        self.gru4 = nn.GRU(input_size=input_size[4], hidden_size=hidden_size[4], **gru_options)
        if attention:
            self.attn1 = Attention()
            self.attn2 = Attention()
            self.attn3 = Attention()
            self.attn4 = Attention()

        # Single linear head over the concatenated branch encodings (a deeper
        # head was tried by the author and left disabled).
        self.fc2 = nn.Sequential(
            nn.Linear(self.encoder_output_size, num_class),
        )

    @property
    def encoder_output_size(self):
        """Width of the concatenated encoding fed to `fc2`.

        The static branch contributes `hidden_size[0]`; each sequence branch
        contributes its hidden size, doubled when bidirectional and doubled
        again when the attention context is appended.
        """
        directions = int(self.bidirectional) + 1
        attention_factor = int(self.attention) + 1
        return self.hidden_size[0] + sum(self.hidden_size[1:]) * directions * attention_factor

    def _encode_sequence(self, branch, sequence, lengths):
        """Encode one padded sequence batch with `gru{branch}` (and `attn{branch}` if enabled)."""
        # Lengths of 0 are raised to 1 before packing, so such rows are encoded
        # from their first (padding) step.
        safe_lengths = lengths.cpu().clamp(min=1)
        packed = nn.utils.rnn.pack_padded_sequence(sequence, safe_lengths, batch_first=True, enforce_sorted=False)
        outputs, hidden = getattr(self, f'gru{branch}')(packed)
        if self.bidirectional:
            # Last layer's two direction states, concatenated.
            summary = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            summary = hidden[-1]
        if self.attention:
            unpacked, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
            context, _ = getattr(self, f'attn{branch}')(unpacked, summary)
            summary = torch.cat((summary, context), dim=1)
        return summary

    def forward(self, data_batch):
        """Return class logits for a batch dict.

        `data_batch` must provide 'info', 'ccba', 'cdtx', 'dp', 'remit' and
        'length'; column k of 'length' holds the true lengths of the k-th
        sequence key. Only the first row of each sample's 'info' is used.
        """
        pieces = [self.fc1(data_batch['info'][:, 0, :])]
        for position, key in enumerate(self._SEQUENCE_KEYS):
            pieces.append(self._encode_sequence(position + 1, data_batch[key], data_batch['length'][:, position]))
        out = self.fc2(torch.cat(pieces, dim=1))
        return out

class FocalLoss(nn.Module):
    """Multi-class focal loss: -alpha_c * (1 - p_c)^gamma * log(p_c) for the target class c.

    Args:
        num_classes: number of classes; must match the second dimension of the logits.
        alpha: per-class weights. A tensor is used as-is; a number `a` becomes
            `[1 - a, a]` (only meaningful for two classes); anything else
            (e.g. None) gives uniform weights of one.
        gamma: focusing exponent applied to `(1 - p)`.
        reduction: 'none' returns the per-sample loss, 'sum' its sum, any other
            value the mean.
    """

    def __init__(self, num_classes=2, alpha=0.25, gamma=2, reduction='mean'):
        super().__init__()
        self.num_classes = num_classes
        if torch.is_tensor(alpha):
            self.alpha = alpha
        elif isinstance(alpha, (float, int)):
            self.alpha = torch.tensor([1 - alpha, alpha])
        else:
            self.alpha = torch.ones(num_classes)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the loss from raw logits `inputs` (batch, classes) and integer class `targets` (batch,)."""
        # log_softmax keeps log-probabilities finite. Taking log(softmax(x))
        # gives -inf once a probability underflows to 0, and the one-hot
        # product below would then turn 0 * inf into NaN.
        log_prob = F.log_softmax(inputs, dim=1)
        prob = log_prob.exp()
        alpha = self.alpha.to(inputs.device)
        per_class = -alpha * torch.pow(1 - prob, self.gamma) * log_prob
        # Keep only the term belonging to each sample's target class.
        target_mask = F.one_hot(targets, num_classes=self.num_classes)
        loss = torch.sum(target_mask * per_class, dim=1)

        if self.reduction == 'none':
            return loss
        if self.reduction == 'sum':
            return loss.sum()
        return loss.mean()

In [None]:
# --- Reproducibility: seed every RNG in use and pin cuDNN to deterministic kernels ---
rand_seed = 0
for seed_everything in (random.seed, np.random.seed, torch.manual_seed):
    seed_everything(rand_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(rand_seed)
    torch.cuda.manual_seed_all(rand_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Device: GPU when present, CPU otherwise ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} for training.')

# --- Cached dataset to use: these two values select the .pt files by name ---
window, max_len = 30, 64

# --- Optimisation ---
num_workers = 0
batch_size = 64
lr, weight_decay = 1e-3, 1e-2
num_epoch = 100
early_stop = 20  # epochs without a new best validation score before stopping

# --- Model: hidden_size[0] is the static branch, [1:] the four sequence branches ---
hidden_size = [32] * 5
num_layers = 3
dropout = 0.1
bidirectional = False
attention = False

In [None]:
# Reload the cached training samples and labels for the configured window/max_len.
cached_data_path = f'data_pt/all_data_{window}_{max_len}.pt'
cached_label_path = f'data_pt/all_label_{window}_{max_len}.pt'
all_data = torch.load(cached_data_path)
all_label = torch.load(cached_label_path)

In [None]:
def f1_loss(y_pred, y_true, eps=1e-7):
    """Differentiable "soft" F1 loss: 1 - F1 computed from probabilities.

    Args:
        y_pred: predicted probabilities for the positive class.
        y_true: ground-truth labels (0/1) with the same shape as `y_pred`.
        eps: small constant that keeps the divisions away from 0/0.

    Returns:
        A scalar tensor, `1 - mean(F1)`. Counts are summed over dimension 0;
        a NaN F1 is treated as 0.
    """
    # Soft confusion counts (true negatives are not needed for F1).
    tp = torch.sum(y_true * y_pred, 0)
    fp = torch.sum((1 - y_true) * y_pred, 0)
    fn = torch.sum(y_true * (1 - y_pred), 0)

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)

    f1 = 2 * precision * recall / (precision + recall + eps)
    f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
    return 1 - torch.mean(f1)

def score(y_true, y_prob):
    """Ranking precision at the cut-off that captures all but one positive.

    Args:
        y_true: list of labels; only `y_true.count(1)` is used.
        y_prob: list of `[probability, label]` pairs.

    Returns:
        `(positives - 1) / k`, where `k` is the number of pairs ranked at or
        above the second-lowest-scored positive. Returns 0.0 when `y_prob`
        holds fewer than two positives, so callers can always compare and
        format the result as a number.
    """
    total_positives = y_true.count(1)
    positives_seen = 0
    # Walk from the lowest score upwards; the second positive met marks the
    # cut-off that still contains every positive except the lowest-ranked one.
    for rank, pair in enumerate(sorted(y_prob)):
        if pair[1] == 1:
            positives_seen += 1
            if positives_seen == 2:
                return (total_positives - 1) / (len(y_prob) - rank)
    return 0.0

# Stratified 80/20 split of the cached samples into training and validation sets.
train_data, vali_data, train_label, vali_label = train_test_split(
    all_data, all_label, test_size=0.2, random_state=rand_seed, stratify=all_label
)
train_dataset = SARDataset(train_data, train_label)
vali_dataset = SARDataset(vali_data, vali_label)
# Reshuffle the training set every epoch so the mini-batches are not identical
# from epoch to epoch. The dedicated seeded generator keeps runs reproducible.
shuffle_generator = torch.Generator()
shuffle_generator.manual_seed(rand_seed)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, generator=shuffle_generator, num_workers=num_workers)
vali_dataloader = DataLoader(dataset=vali_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

model = SARModel(
    input_size=[all_data[0][key].shape[-1] for key in ('info', 'ccba', 'cdtx', 'dp', 'remit')],
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    attention=attention,
    num_class=2
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.0)
# Alternative: class-balanced cross entropy.
# criterion = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor(compute_class_weight(class_weight='balanced', classes=np.unique(all_label), y=[v.item() for v in all_label])).to(device))
focal_loss = FocalLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)


def run_epoch(dataloader, training):
    """Run one pass over `dataloader`; update the model only when `training` is True.

    Returns:
        (mean batch loss, accuracy, precision, recall, F1, ranking score).
        The ranking score is 0.0 when `score` cannot be computed (it returned None).
    """
    model.train(training)
    total_loss, total_correct = 0.0, 0
    y_true, y_pred, y_prob = [], [], []
    with torch.set_grad_enabled(training):
        for data_batch, data_label in dataloader:
            data_batch = {k: v.to(device) for k, v in data_batch.items()}
            data_label = data_label.to(device)
            pred = model(data_batch)
            loss = criterion(pred, data_label)
            # Alternative objectives:
            # loss = focal_loss(pred, data_label)
            # loss = focal_loss(pred, data_label) + f1_loss(F.softmax(pred, dim=1)[:, 1], data_label)

            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()

            hard_pred = pred.argmax(1)
            total_loss += loss.detach().item()
            total_correct += (hard_pred == data_label).sum().item()
            y_true += data_label.tolist()
            y_pred += hard_pred.tolist()
            # [positive-class probability, label] pairs, the format `score` expects.
            y_prob += torch.cat((F.softmax(pred, dim=1)[:, 1].unsqueeze(1), data_label.unsqueeze(1)), 1).tolist()

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary', zero_division=0)
    ranking = score(y_true, y_prob)
    if ranking is None:  # fewer than two positives in this split
        ranking = 0.0
    return total_loss / len(dataloader), total_correct / len(dataloader.dataset), precision, recall, f1, ranking


total_time, max_score, count = 0, 0, 0
for epoch in range(num_epoch):
    start = time.time()

    loss_t, acc_t, p_t, r_t, f_t, s_t = run_epoch(train_dataloader, training=True)
    loss_v, acc_v, p_v, r_v, f_v, s_v = run_epoch(vali_dataloader, training=False)

    total_time += time.time() - start
    print('Epoch {:3d}/{} Train L/A/P/R/F/S {:.6f}/{:.6f}/{:.6f}/{:.6f}/{:.6f}/{:.6f} Vali L/A/P/R/F/S {:.6f}/{:.6f}/{:.6f}/{:.6f}/{:.6f}/{:.6f} Time {:.6f}s'.format(epoch + 1, num_epoch, loss_t, acc_t, p_t, r_t, f_t, s_t, loss_v, acc_v, p_v, r_v, f_v, s_v, time.time() - start))

    # Checkpoint on a new best validation ranking score; stop after
    # `early_stop` consecutive epochs without improvement.
    if s_v > max_score:
        count = 0
        max_score = s_v
        torch.save(model.state_dict(), 'best3.pt')
    else:
        count += 1
        if count >= early_stop:
            break

    # Decay the learning rate until it has dropped to a tenth of its start value.
    if scheduler.get_last_lr()[0] > lr / 10:
        scheduler.step()
print(f'Training Time: {total_time}s')

# Test

In [None]:
# Inference: score every cached test alert with the best checkpoint, report the
# public score, and write the submission file.
test_data = torch.load(f'data_pt/test_data_{window}_{max_len}.pt')
test_dataset = SARDataset(test_data, [0]*len(test_data))  # dummy labels; only the features are used
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

model = SARModel(
    input_size=[test_data[0][key].shape[-1] for key in ('info', 'ccba', 'cdtx', 'dp', 'remit')],
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    attention=attention,
    num_class=2
)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
ckpt = torch.load('best3.pt', map_location=device)
model.load_state_dict(ckpt)
model.to(device)

prediction = {}  # alert_key -> predicted probability of the positive class
model.eval()
with torch.no_grad():
    for data_batch, _ in test_dataloader:
        alert_key = data_batch['alert_key'].cpu().tolist()[0]
        data_batch = {k: v.to(device) for k, v in data_batch.items() if k != 'alert_key'}
        pred = model(data_batch)
        prediction[alert_key] = F.softmax(pred, dim=1)[0][1].cpu().tolist()
ranked = sorted(prediction.items(), key=lambda x: x[1])[::-1]
print(ranked)

# Public score: walk the public alerts from highest to lowest prediction until
# all but one of the known positives have been found, then report the precision.
df_answer = pd.read_csv('data/24_ESun_public_y_answer.csv')
public_keys = set(df_answer['alert_key'])
answer = set(df_answer[df_answer['sar_flag'] == 1]['alert_key'])
count, count2 = 0, 0
for i, (k, p) in enumerate(ranked):
    if k in public_keys:
        count2 += 1
        if k in answer:
            count += 1
        if count == len(answer) - 1:
            print('Public Score:', count / count2)
            break

# Submission: alerts listed in the public answer file get their known label;
# every other alert gets its prediction divided by the largest prediction.
first_answer = df_answer.drop_duplicates('alert_key')  # first row wins for a repeated key
public_labels = dict(zip(first_answer['alert_key'], first_answer['sar_flag']))
sample_submission = pd.read_csv('sample_submission.csv')
max_pred = max(prediction.values())
for submit_key in sample_submission['alert_key']:
    if submit_key in public_labels:
        prediction[submit_key] = public_labels[submit_key]
    else:
        prediction[submit_key] = prediction[submit_key] / max_pred

with open('submission.csv', 'w') as f:
    f.write('alert_key,probability\n')
    for k, p in sorted(prediction.items(), key=lambda x: x[1])[::-1]:
        f.write(f'{int(k)},{p}\n')