# Q3: SVM Classification — Customer Churn
Dataset: `svm_churn_dataset.csv`

In [None]:
# Common imports used across notebooks
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.svm import SVC
# Global configuration: plot theme and the single seed passed to every
# randomized step (train/test split, SVC probability calibration).
sns.set_theme()  # preferred interface; sns.set() is a legacy alias that may be removed
RANDOM_STATE = 42


In [None]:
# Load the raw churn data. The directory can be overridden with the
# CHURN_DATA_DIR environment variable so the notebook runs on other machines;
# the default is the original location.
DATA_DIR = Path(os.environ.get('CHURN_DATA_DIR', '/mnt/data/aiml'))
df = pd.read_csv(DATA_DIR / 'svm_churn_dataset.csv')
assert 'churn' in df.columns, "expected a 'churn' target column"
df.head()

In [None]:
# Preprocess: drop the identifier column and encode the target as integer 0/1.
# drop(..., errors='ignore') returns a new frame and is a no-op when the column
# is already gone, so the cell is safe to re-run.
df = df.drop(columns=['customer_id'], errors='ignore')

# Normalise case/whitespace and accept already-encoded (0/1, 0.0/1.0) or boolean
# labels, so re-running this cell is idempotent. Anything else (typos, NaN) is
# rejected here: passing it through would give a mixed-type target that breaks
# binary F1 scoring later with a far less obvious error.
churn_labels = {
    'yes': 1, 'no': 0,
    'true': 1, 'false': 0,
    '1': 1, '0': 0,
    '1.0': 1, '0.0': 0,
}
churn_encoded = df['churn'].astype(str).str.strip().str.lower().map(churn_labels)
unmapped = df.loc[churn_encoded.isna(), 'churn'].unique()
if len(unmapped):
    raise ValueError(f'Unrecognised churn labels: {unmapped!r}')
df['churn'] = churn_encoded.astype(int)

df['churn'].value_counts()

In [None]:
# Model selection: preprocessing and the SVC live in one Pipeline, so imputation,
# scaling and encoding are fit only on the training folds inside cross-validation.
target = 'churn'
y = df[target]
X = df.drop(columns=[target])

# 'number' covers every numeric dtype (int32, float32, nullable Int64, ...),
# not only int64/float64.
num_cols = X.select_dtypes(include='number').columns.tolist()
cat_cols = X.select_dtypes(include=['object', 'category']).columns.tolist()

# ColumnTransformer drops unlisted columns by default (remainder='drop');
# surface them instead of silently losing features.
unused_cols = [c for c in X.columns if c not in num_cols and c not in cat_cols]
if unused_cols:
    warnings.warn(f'Columns ignored by the preprocessor: {unused_cols}')

num_transform = Pipeline([
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler()),
])
# No drop='first': the SVM is regularized, so the full one-hot basis is fine.
# Also, with handle_unknown='ignore' an unseen category encodes as all zeros,
# which would collide with a dropped reference category.
cat_transform = Pipeline([
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('ohe', OneHotEncoder(handle_unknown='ignore')),
])

pre = ColumnTransformer([('num', num_transform, num_cols), ('cat', cat_transform, cat_cols)])

# probability=True fits Platt scaling with an internal, randomized CV;
# seed it so predict_proba is reproducible across runs.
pipe = Pipeline([
    ('pre', pre),
    ('svc', SVC(probability=True, class_weight='balanced', random_state=RANDOM_STATE)),
])

X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=RANDOM_STATE
)

# gamma only affects the RBF kernel; crossing it with 'linear' refits identical
# models, so each kernel gets its own sub-grid.
C_VALUES = [0.1, 1, 10]
param_grid = [
    {'svc__kernel': ['rbf'], 'svc__C': C_VALUES, 'svc__gamma': ['scale', 'auto']},
    {'svc__kernel': ['linear'], 'svc__C': C_VALUES},
]
gs = GridSearchCV(pipe, param_grid, cv=4, scoring='f1', n_jobs=-1)
gs.fit(X_train, y_train)
print('Best params:', gs.best_params_)
print(f'Best CV F1: {gs.best_score_:.3f}')

In [None]:
# Evaluate the refit best pipeline on the held-out test split.
y_pred = gs.predict(X_test)
print(classification_report(y_test, y_pred))

# Pin the label order so matrix rows/columns match the tick labels drawn below.
class_labels = gs.classes_
cm = confusion_matrix(y_test, y_pred, labels=class_labels)

fig, ax = plt.subplots()
sns.heatmap(
    cm, annot=True, fmt='d',
    xticklabels=class_labels, yticklabels=class_labels, ax=ax,
)
ax.set(xlabel='Predicted churn', ylabel='Actual churn', title='SVC confusion matrix (test set)')
plt.show()