Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ TESTS = \
Constructor \
Field \
Static \
Invokevirtual
Invokevirtual \
Inherit \
Initializer

check: $(addprefix tests/,$(TESTS:=-result.out))

Expand Down
1 change: 1 addition & 0 deletions class-heap.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ void free_class_heap()
free(constant->info);
}
free(class_heap.class_info[i]->clazz->constant_pool.constant_pool);
free(class_heap.class_info[i]->clazz->info);

field_t *field = class_heap.class_info[i]->clazz->fields;
for (u2 j = 0; j < class_heap.class_info[i]->clazz->fields_count;
Expand Down
17 changes: 10 additions & 7 deletions classfile.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ class_header_t get_class_header(FILE *class_file)
};
}

class_info_t get_class_info(FILE *class_file)
class_info_t *get_class_info(FILE *class_file)
{
class_info_t info = {
.access_flags = read_u2(class_file),
.this_class = read_u2(class_file),
.super_class = read_u2(class_file),
};
class_info_t *info = malloc(sizeof(class_info_t));
info->access_flags = read_u2(class_file);
info->this_class = read_u2(class_file);
info->super_class = read_u2(class_file);

u2 interfaces_count = read_u2(class_file);
assert(!interfaces_count && "This VM does not support interfaces.");
return info;
Expand Down Expand Up @@ -303,12 +303,15 @@ class_file_t get_class(FILE *class_file)
class_file_t clazz = {.constant_pool = get_constant_pool(class_file)};

/* Read information about the class that was compiled. */
get_class_info(class_file);
clazz.info = get_class_info(class_file);

/* Read the list of fields */
clazz.fields = get_fields(class_file, &clazz.constant_pool, &clazz);

/* Read the list of static methods */
clazz.methods = get_methods(class_file, &clazz.constant_pool);

clazz.initialized = false;

return clazz;
}
4 changes: 3 additions & 1 deletion classfile.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ typedef struct {

typedef struct {
constant_pool_t constant_pool;
class_info_t *info;
method_t *methods;
field_t *fields;
u2 fields_count;
bool initialized;
} class_file_t;

typedef struct {
Expand All @@ -67,7 +69,7 @@ typedef struct {
} meta_class_t;

class_header_t get_class_header(FILE *class_file);
class_info_t get_class_info(FILE *class_file);
class_info_t *get_class_info(FILE *class_file);
method_t *get_methods(FILE *class_file, constant_pool_t *cp);
void read_method_attributes(FILE *class_file,
method_info *info,
Expand Down
150 changes: 110 additions & 40 deletions jvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,28 @@ stack_entry_t *execute(method_t *method,

/* the method to be called */
char *method_name, *method_descriptor, *class_name;
class_name = find_method_info_from_index(index, clazz, &method_name,
&method_descriptor);
method_t *own_method = NULL;
class_file_t *target_class = NULL;

/* recursively find method from child to parent */
while (!own_method) {
if (!target_class)
class_name = find_method_info_from_index(
index, clazz, &method_name, &method_descriptor);
else
class_name = find_class_name_from_index(
target_class->info->super_class, target_class);
find_or_add_class_to_heap(class_name, prefix, &target_class);
assert(target_class &&
"Failed to load class in i_invokestatic");
own_method =
find_method(method_name, method_descriptor, target_class);
}

class_file_t *target_class;
if (find_or_add_class_to_heap(class_name, prefix, &target_class)) {
/* Call static initialization */
/* call static initialization. Only the class that contains this
* method should do static initialization */
if (!target_class->initialized) {
target_class->initialized = true;
method_t *method = find_method("<clinit>", "()V", target_class);
if (method) {
local_variable_t own_locals[method->code.max_locals];
Expand All @@ -279,10 +295,7 @@ stack_entry_t *execute(method_t *method,
free(exec_res);
}
}
assert(target_class && "Failed to load class in invokestatic");

method_t *own_method =
find_method(method_name, method_descriptor, target_class);
uint16_t num_params = get_number_of_parameters(own_method);
local_variable_t own_locals[own_method->code.max_locals];
for (int i = num_params - 1; i >= 0; i--)
Expand Down Expand Up @@ -794,19 +807,33 @@ stack_entry_t *execute(method_t *method,
uint16_t index = ((param1 << 8) | param2);

char *field_name, *field_descriptor, *class_name;
field_t *field = NULL;
class_file_t *target_class = NULL;

class_name = find_field_info_from_index(index, clazz, &field_name,
&field_descriptor);

/* skip java.lang.System in order to support java print
* method */
if (!strcmp(class_name, "java/lang/System")) {
pc += 3;
break;
}

class_file_t *target_class;
if (find_or_add_class_to_heap(class_name, prefix, &target_class)) {
/* Call static initialization */
while (!field) {
if (target_class)
class_name = find_class_name_from_index(
target_class->info->super_class, target_class);

find_or_add_class_to_heap(class_name, prefix, &target_class);
assert(target_class && "Failed to load class in i_getstatic");

field = find_field(field_name, field_descriptor, target_class);
}

/* call static initialization. Only the class that contains this
* field should do static initialization */
if (!target_class->initialized) {
target_class->initialized = true;
method_t *method = find_method("<clinit>", "()V", target_class);
if (method) {
local_variable_t own_locals[method->code.max_locals];
Expand All @@ -818,9 +845,6 @@ stack_entry_t *execute(method_t *method,
}
}

field_t *field =
find_field(field_name, field_descriptor, target_class);

switch (field_descriptor[0]) {
case 'B':
/* signed byte */
Expand Down Expand Up @@ -866,12 +890,33 @@ stack_entry_t *execute(method_t *method,
uint16_t index = ((param1 << 8) | param2);

char *field_name, *field_descriptor, *class_name;
field_t *field = NULL;
class_file_t *target_class = NULL;
class_name = find_field_info_from_index(index, clazz, &field_name,
&field_descriptor);

class_file_t *target_class;
if (find_or_add_class_to_heap(class_name, prefix, &target_class)) {
/* Call static initialization */
/* skip java.lang.System in order to support java print
* method */
if (!strcmp(class_name, "java/lang/System")) {
pc += 3;
break;
}

while (!field) {
if (target_class)
class_name = find_class_name_from_index(
target_class->info->super_class, target_class);

find_or_add_class_to_heap(class_name, prefix, &target_class);
assert(target_class && "Failed to load class in i_putstatic");

field = find_field(field_name, field_descriptor, target_class);
}

/* call static initialization. Only the class that contains this
* field should do static initialization */
if (!target_class->initialized) {
target_class->initialized = true;
method_t *method = find_method("<clinit>", "()V", target_class);
if (method) {
local_variable_t own_locals[method->code.max_locals];
Expand All @@ -882,8 +927,6 @@ stack_entry_t *execute(method_t *method,
free(exec_res);
}
}
field_t *field =
find_field(field_name, field_descriptor, target_class);

switch (field_descriptor[0]) {
case 'B':
Expand Down Expand Up @@ -938,6 +981,9 @@ stack_entry_t *execute(method_t *method,

/* the method to be called */
char *method_name, *method_descriptor, *class_name;
class_file_t *target_class = NULL;
method_t *method = NULL;

class_name = find_method_info_from_index(index, clazz, &method_name,
&method_descriptor);

Expand All @@ -964,9 +1010,22 @@ stack_entry_t *execute(method_t *method,
}

/* FIXME: consider method modifier */
class_file_t *target_class;
if (find_or_add_class_to_heap(class_name, prefix, &target_class)) {
/* Call static initialization */
/* recursively find method from child to parent */
while (!method) {
if (target_class)
class_name = find_class_name_from_index(
target_class->info->super_class, target_class);
find_or_add_class_to_heap(class_name, prefix, &target_class);
assert(target_class &&
"Failed to load class in i_invokevirtual");
method =
find_method(method_name, method_descriptor, target_class);
}

/* call static initialization. Only the class that contains this
* method should do static initialization */
if (!target_class->initialized) {
target_class->initialized = true;
method_t *method = find_method("<clinit>", "()V", target_class);
if (method) {
local_variable_t own_locals[method->code.max_locals];
Expand All @@ -977,8 +1036,7 @@ stack_entry_t *execute(method_t *method,
free(exec_res);
}
}
method_t *method =
find_method(method_name, method_descriptor, target_class);

uint16_t num_params = get_number_of_parameters(method);
local_variable_t own_locals[method->code.max_locals];
memset(own_locals, 0, sizeof(own_locals));
Expand Down Expand Up @@ -1138,8 +1196,26 @@ stack_entry_t *execute(method_t *method,

char *class_name = find_class_name_from_index(index, clazz);
class_file_t *target_class;
if (find_or_add_class_to_heap(class_name, prefix, &target_class)) {
/* Call static initialization */

/* FIXME: use linked list to prevent wasted space */
class_file_t **stack = malloc(sizeof(class_file_t *) * 100);
size_t count = 0;
while (true) {
find_or_add_class_to_heap(class_name, prefix, &target_class);
assert(target_class && "Failed to load class in i_new");
stack[count++] = target_class;
class_name = find_class_name_from_index(
target_class->info->super_class, target_class);
if (!strcmp(class_name, "java/lang/Object"))
break;
}

/* call static initialization */
while (count) {
target_class = stack[--count];
if (target_class->initialized)
continue;
target_class->initialized = true;
method_t *method = find_method("<clinit>", "()V", target_class);
if (method) {
local_variable_t own_locals[method->code.max_locals];
Expand All @@ -1150,7 +1226,7 @@ stack_entry_t *execute(method_t *method,
free(exec_res);
}
}
assert(target_class && "Failed to load class in new");
free(stack);

object_t *object = create_object(target_class);
push_ref(op_stack, object);
Expand Down Expand Up @@ -1178,8 +1254,12 @@ stack_entry_t *execute(method_t *method,
}

class_file_t *target_class;
if (find_or_add_class_to_heap(class_name, prefix, &target_class)) {
/* Call static initialization */
find_or_add_class_to_heap(class_name, prefix, &target_class);
assert(target_class && "Failed to load class in i_invokespecial");

/* call static initialization */
if (!target_class->initialized) {
target_class->initialized = true;
method_t *method = find_method("<clinit>", "()V", target_class);
if (method) {
local_variable_t own_locals[method->code.max_locals];
Expand All @@ -1190,7 +1270,6 @@ stack_entry_t *execute(method_t *method,
free(exec_res);
}
}
assert(target_class && "Failed to load class in i_invokespecial");

/* find constructor method from class */
method_t *constructor =
Expand Down Expand Up @@ -1254,15 +1333,6 @@ int main(int argc, char *argv[])
prefix[match - argv[1] + 1] = '\0';
}

method_t *method = find_method("<clinit>", "()V", clazz);
if (method) {
local_variable_t own_locals[method->code.max_locals];
stack_entry_t *exec_res = execute(method, own_locals, clazz);
assert(exec_res->type == STACK_ENTRY_NONE &&
"<clinit> must not return a value");
free(exec_res);
}

/* execute the main method if found */
method_t *main_method =
find_method("main", "([Ljava/lang/String;)V", clazz);
Expand Down
54 changes: 54 additions & 0 deletions tests/Inherit.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
public class Inherit {
static int x = 1;

public static void static_call() {
System.out.println(1);
}

public void virtual_call() {
System.out.println(2);
}

public static void main(String[] args) {
Inherit obj = new Inherit();
InheritA objA = new InheritA();
InheritB objB = new InheritB();

/* check shared static fields */
System.out.println(obj.x);
System.out.println(objA.x);
System.out.println(objB.x);
obj.x = 2;
System.out.println(obj.x);
System.out.println(objA.x);
System.out.println(objB.x);
objA.x = 3;
System.out.println(obj.x);
System.out.println(objA.x);
System.out.println(objB.x);

/* check static methods inheritance (compiler will replace objects with classes) */
obj.static_call();
objA.static_call();
objB.static_call();

/* check virtual methods inheritance */
obj.virtual_call();
objA.virtual_call();
objB.virtual_call();
}
}

class InheritA extends Inherit {

}

class InheritB extends InheritA {
/* check override */
public void virtual_call() {
System.out.println(3);
}
public static void static_call() {
System.out.println(4);
}
}
Loading